Skip to content

Commit 7c7f6fc

Browse files
committed
Better input inference in mapping + default in multirate = same-scale scalar when the temporal cache path misses in multirate execution
Added logic to read mapped_variables for each input. If mapping gives one unambiguous source (scale,var), inference now uses it directly. If mapping exists but is multi-source/self-mapped, generic same-name inference is skipped (to avoid wrong auto-bindings). Kept strict ambiguity error when truly ambiguous and no mapping hint resolves it.
1 parent 0a4e54c commit 7c7f6fc

3 files changed

Lines changed: 235 additions & 6 deletions

File tree

src/mtg/model_spec_inference.jl

Lines changed: 140 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,112 @@ function _default_policy_for_inferred_binding(model_specs, source_scale::Symbol,
368368
)
369369
end
370370

371+
function _mapped_source_scales_for_input(spec::ModelSpec, input_var::Symbol)
372+
mapped = mapped_variables_(spec)
373+
isempty(mapped) && return Set{Symbol}()
374+
375+
scales = Set{Symbol}()
376+
for mv in mapped
377+
mapped_input = first(mv)
378+
mapped_input = mapped_input isa PreviousTimeStep ? mapped_input.variable : mapped_input
379+
mapped_input == input_var || continue
380+
381+
rhs = last(mv)
382+
if rhs isa Pair{Symbol,Symbol}
383+
src_scale = first(rhs)
384+
src_scale == Symbol("") || push!(scales, src_scale)
385+
elseif rhs isa AbstractVector
386+
for item in rhs
387+
item isa Pair{Symbol,Symbol} || continue
388+
src_scale = first(item)
389+
src_scale == Symbol("") || push!(scales, src_scale)
390+
end
391+
end
392+
end
393+
394+
return scales
395+
end
396+
397+
function _input_has_multiscale_mapping(spec::ModelSpec, input_var::Symbol)
398+
mapped = mapped_variables_(spec)
399+
isempty(mapped) && return false
400+
401+
for mv in mapped
402+
mapped_input = first(mv)
403+
mapped_input = mapped_input isa PreviousTimeStep ? mapped_input.variable : mapped_input
404+
mapped_input == input_var && return true
405+
end
406+
407+
return false
408+
end
409+
410+
function _mapped_sources_for_input(spec::ModelSpec, input_var::Symbol)
411+
mapped = mapped_variables_(spec)
412+
isempty(mapped) && return Pair{Symbol,Symbol}[]
413+
414+
sources = Pair{Symbol,Symbol}[]
415+
for mv in mapped
416+
mapped_input = first(mv)
417+
mapped_input = mapped_input isa PreviousTimeStep ? mapped_input.variable : mapped_input
418+
mapped_input == input_var || continue
419+
420+
rhs = last(mv)
421+
if rhs isa Pair{Symbol,Symbol}
422+
push!(sources, rhs)
423+
elseif rhs isa AbstractVector
424+
for item in rhs
425+
item isa Pair{Symbol,Symbol} || continue
426+
push!(sources, item)
427+
end
428+
end
429+
end
430+
431+
return sources
432+
end
433+
434+
function _infer_binding_from_multiscale_mapping(
435+
model_specs,
436+
scale::Symbol,
437+
process::Symbol,
438+
spec::ModelSpec,
439+
input_var::Symbol;
440+
active_processes_by_scale=nothing
441+
)
442+
has_mapping = _input_has_multiscale_mapping(spec, input_var)
443+
has_mapping || return nothing
444+
445+
mapped_sources = _mapped_sources_for_input(spec, input_var)
446+
# Mapping exists but does not point to another scale (self/same-scale aliasing):
447+
# avoid generic same-name inference in that case.
448+
filtered_sources = filter(s -> first(s) != Symbol(""), mapped_sources)
449+
isempty(filtered_sources) && return :skip
450+
451+
# Multi-source mapping (e.g. vectors from several scales) cannot be represented
452+
# as one `InputBindings` entry; keep binding unresolved and skip generic inference.
453+
length(filtered_sources) == 1 || return :skip
454+
455+
src = only(filtered_sources)
456+
src_scale = first(src)
457+
src_var = last(src)
458+
haskey(model_specs, src_scale) || return :skip
459+
460+
procs = Symbol[]
461+
for (src_process, src_spec) in pairs(model_specs[src_scale])
462+
if !isnothing(active_processes_by_scale)
463+
active = get(active_processes_by_scale, src_scale, Set{Symbol}())
464+
src_process in active || continue
465+
end
466+
src_var in keys(outputs_(model_(src_spec))) || continue
467+
_is_stream_only_output(src_spec, src_var) && continue
468+
push!(procs, src_process)
469+
end
470+
471+
length(procs) == 1 || return :skip
472+
src_process = only(procs)
473+
policy = _default_policy_for_inferred_binding(model_specs, src_scale, src_process, src_var)
474+
return (process=src_process, var=src_var, scale=src_scale, policy=policy)
475+
end
476+
371477
function _infer_input_binding_for_var(
372478
model_specs,
373479
scale::Symbol,
@@ -388,7 +494,7 @@ function _infer_input_binding_for_var(
388494
if length(same_scale) == 1
389495
c = only(same_scale)
390496
policy = _default_policy_for_inferred_binding(model_specs, c.scale, c.process, c.var)
391-
return (process=c.process, var=c.var, policy=policy)
497+
return (process=c.process, var=c.var, scale=c.scale, policy=policy)
392498
elseif length(same_scale) > 1
393499
error(
394500
"Ambiguous inferred producer for input `$(input_var)` in process `$(process)` at scale `$(scale)`. ",
@@ -415,9 +521,25 @@ function _infer_input_binding_for_var(
415521
policy = _default_policy_for_inferred_binding(model_specs, src_scale, proc, input_var)
416522
return (process=proc, var=input_var, scale=src_scale, policy=policy)
417523
end
418-
# Same process name appears at multiple scales (common in multiscale
419-
# mappings). Keep scale unresolved so runtime resolves through parent links.
420-
return (process=proc, var=input_var, policy=HoldLast())
524+
525+
# When multiscale mapping already declares a source scale for this
526+
# input, use it to disambiguate instead of forcing explicit bindings.
527+
consumer_spec = model_specs[scale][process]
528+
mapped_scales = _mapped_source_scales_for_input(consumer_spec, input_var)
529+
candidate_scales = Set(scales)
530+
hinted_scales = intersect(mapped_scales, candidate_scales)
531+
if length(hinted_scales) == 1
532+
src_scale = only(hinted_scales)
533+
policy = _default_policy_for_inferred_binding(model_specs, src_scale, proc, input_var)
534+
return (process=proc, var=input_var, scale=src_scale, policy=policy)
535+
end
536+
537+
error(
538+
"Ambiguous inferred producer for input `$(input_var)` in process `$(process)` at scale `$(scale)`. ",
539+
"Process `$(proc)` publishes this variable at multiple reachable scales: $(join(scales, ", ")). ",
540+
"Please provide explicit `InputBindings(...)` with `scale`, ",
541+
"or add a `MultiScaleModel(...)` mapping so the source scale is unambiguous."
542+
)
421543
end
422544

423545
error(
@@ -453,6 +575,20 @@ function _infer_input_bindings!(model_specs; scale_reachability=nothing, active_
453575

454576
for input_var in model_inputs
455577
input_var in keys(current_bindings) && continue
578+
mapped_binding = _infer_binding_from_multiscale_mapping(
579+
model_specs,
580+
scale,
581+
process,
582+
spec,
583+
input_var;
584+
active_processes_by_scale=active_processes_by_scale
585+
)
586+
if mapped_binding === :skip
587+
continue
588+
elseif !isnothing(mapped_binding)
589+
push!(inferred, input_var => mapped_binding)
590+
continue
591+
end
456592
inferred_binding = _infer_input_binding_for_var(
457593
model_specs,
458594
scale,

src/time/runtime/input_resolution.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ function _assign_input_value!(st::Status, input_var::Symbol, value)
156156
return nothing
157157
end
158158

159+
function _same_scale_status_value(source_statuses, target_node_id::Int, source_var::Symbol)
160+
for src_st in source_statuses
161+
node_id(src_st.node) == target_node_id || continue
162+
source_var in keys(src_st) || continue
163+
return src_st[source_var], true
164+
end
165+
return nothing, false
166+
end
167+
159168
"""
160169
_resolve_input_windowed(sim, node, st, input_var, source_scale, source_process, source_var, t_start, t_end, policy)
161170
@@ -210,6 +219,16 @@ function _resolve_input_windowed(
210219
return nothing
211220
end
212221

222+
# Same-scale scalar fallback: prefer the value attached to the consumer node
223+
# before scanning all source nodes (which can be ambiguous in dense scales).
224+
if source_scale == node.scale
225+
vv, found = _same_scale_status_value(source_statuses, consumer_node_id, source_var)
226+
if found
227+
_assign_input_value!(st, input_var, vv)
228+
return nothing
229+
end
230+
end
231+
213232
# Cross-scale scalar fallback: allow unique producer value at source scale.
214233
candidates = Any[]
215234
for src_st in source_statuses
@@ -357,6 +376,16 @@ function _resolve_input_holdlast(
357376
return nothing
358377
end
359378

379+
# Same-scale scalar fallback: prefer the value attached to the consumer node
380+
# before scanning all source nodes (which can be ambiguous in dense scales).
381+
if source_scale == node.scale
382+
vv, found = _same_scale_status_value(source_statuses, consumer_node_id, source_var)
383+
if found
384+
_assign_input_value!(st, input_var, vv)
385+
return nothing
386+
end
387+
end
388+
360389
# Cross-scale scalar fallback: allow unique producer value at source scale.
361390
candidates = Any[]
362391
for src_st in source_statuses

test/test-multirate-runtime.jl

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,24 @@ function PlantSimEngine.run!(::MRZConsumerModel, models, status, meteo, constant
210210
status.ZZ = status.Z
211211
end
212212

213+
PlantSimEngine.@process "mrmultiscaletsource" verbose = false
214+
struct MRMultiScaleTSourceModel <: AbstractMrmultiscaletsourceModel
215+
tt::Float64
216+
end
217+
PlantSimEngine.inputs_(::MRMultiScaleTSourceModel) = NamedTuple()
218+
PlantSimEngine.outputs_(::MRMultiScaleTSourceModel) = (TT=-Inf,)
219+
function PlantSimEngine.run!(m::MRMultiScaleTSourceModel, models, status, meteo, constants=nothing, extra=nothing)
220+
status.TT = m.tt
221+
end
222+
223+
PlantSimEngine.@process "mrleafttconsumer" verbose = false
224+
struct MRLeafTTConsumerModel <: AbstractMrleafttconsumerModel end
225+
PlantSimEngine.inputs_(::MRLeafTTConsumerModel) = (TT=-Inf,)
226+
PlantSimEngine.outputs_(::MRLeafTTConsumerModel) = (TT_used=-Inf,)
227+
function PlantSimEngine.run!(::MRLeafTTConsumerModel, models, status, meteo, constants=nothing, extra=nothing)
228+
status.TT_used = status.TT
229+
end
230+
213231
PlantSimEngine.@process "mrmissinginputconsumer" verbose = false
214232
struct MRMissingInputConsumerModel <: AbstractMrmissinginputconsumerModel end
215233
PlantSimEngine.inputs_(::MRMissingInputConsumerModel) = (U=-Inf,)
@@ -1044,7 +1062,53 @@ PlantSimEngine.meteo_hint(::Type{<:MRMeteoHintConsumerModel}) = (
10441062
@test input_bindings(spec_lineage_infer).Z.process == :mrancestorsource
10451063
@test input_bindings(spec_lineage_infer).Z.scale == :Plant
10461064

1047-
# Expectation 24c: same-rate hard dependencies are ignored for auto bindings and canonical publisher checks.
1065+
# Expectation 24c: inferred bindings prefer same-scale producers when available.
1066+
mapping_same_scale_preferred = ModelMapping(
1067+
:Plant => (
1068+
MRMultiScaleTSourceModel(100.0),
1069+
),
1070+
:Internode => (
1071+
MRMultiScaleTSourceModel(10.0),
1072+
),
1073+
:Leaf => (
1074+
MRMultiScaleTSourceModel(1.0),
1075+
MRLeafTTConsumerModel(),
1076+
),
1077+
)
1078+
sim_same_scale_preferred = PlantSimEngine.GraphSimulation(
1079+
mtg,
1080+
mapping_same_scale_preferred,
1081+
nsteps=1,
1082+
check=true,
1083+
outputs=Dict(:Leaf => (:TT, :TT_used))
1084+
)
1085+
run!(sim_same_scale_preferred, meteo, executor=SequentialEx())
1086+
spec_same_scale_preferred = PlantSimEngine.get_model_specs(sim_same_scale_preferred)[:Leaf][:mrleafttconsumer]
1087+
@test input_bindings(spec_same_scale_preferred).TT.process == :mrmultiscaletsource
1088+
@test input_bindings(spec_same_scale_preferred).TT.scale == :Leaf
1089+
@test status(sim_same_scale_preferred)[:Leaf][1].TT_used == 1.0
1090+
1091+
# Expectation 24d: when no same-scale producer exists, repeated process names across scales require explicit scale mapping.
1092+
mapping_cross_scale_same_process_ambiguous = ModelMapping(
1093+
:Plant => (
1094+
MRMultiScaleTSourceModel(100.0),
1095+
),
1096+
:Internode => (
1097+
MRMultiScaleTSourceModel(10.0),
1098+
),
1099+
:Leaf => (
1100+
MRLeafTTConsumerModel(),
1101+
),
1102+
)
1103+
@test_throws "Ambiguous inferred producer for input `TT`" PlantSimEngine.GraphSimulation(
1104+
mtg,
1105+
mapping_cross_scale_same_process_ambiguous,
1106+
nsteps=1,
1107+
check=true,
1108+
outputs=Dict(:Leaf => (:TT_used,))
1109+
)
1110+
1111+
# Expectation 24e: same-rate hard dependencies are ignored for auto bindings and canonical publisher checks.
10481112
mapping_hard_same_rate = ModelMapping(
10491113
:Leaf => (
10501114
ModelSpec(MRHardParentModel()) |> TimeStepModel(1.0),
@@ -1058,7 +1122,7 @@ PlantSimEngine.meteo_hint(::Type{<:MRMeteoHintConsumerModel}) = (
10581122
@test input_bindings(spec_hard_same_rate).A.process == :mrhardparent
10591123
@test status(sim_hard_same_rate)[:Leaf][1].B == 5.0
10601124

1061-
# Expectation 24d: different-rate hard dependencies remain strict and require explicit disambiguation.
1125+
# Expectation 24f: different-rate hard dependencies remain strict and require explicit disambiguation.
10621126
mapping_hard_different_rate = ModelMapping(
10631127
:Leaf => (
10641128
ModelSpec(MRHardParentModel()) |> TimeStepModel(1.0),

0 commit comments

Comments
 (0)