Skip to content

Commit 3518cf5

Browse files
committed
Add interfaces to ModelMapping{SingleScale} so we take the right routes + are able to use it as a ModelList
1 parent ab564e0 commit 3518cf5

7 files changed

Lines changed: 73 additions & 9 deletions

File tree

src/PlantSimEngine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ export AbstractTimeReducer, MeanWeighted, MeanReducer, SumReducer, MinReducer, M
124124
export OutputCache, HoldLastCache, InterpolateCache, IntegrateCache, AggregateCache
125125
export TemporalState
126126
export OutputRequest, collect_outputs
127-
export ModelList, MultiScaleModel, ModelMapping, MultiScaleMapping, ModelSpec, TimeStepModel, InputBindings, MeteoBindings, MeteoWindow, OutputRouting, ScopeModel
127+
export ModelList, MultiScaleModel, ModelMapping, ModelSpec, TimeStepModel, InputBindings, MeteoBindings, MeteoWindow, OutputRouting, ScopeModel
128128
export resolved_model_specs, explain_model_specs
129129
export RMSE, NRMSE, EF, dr
130130
export Status, TimeStepTable, status

src/checks/dimensions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ function check_dimensions(component::T, w) where {T<:ModelList}
4747
check_dimensions(status(component), w)
4848
end
4949

50+
function check_dimensions(component::ModelMapping{SingleScale}, w)
51+
check_dimensions(status(component), w)
52+
end
53+
5054
# for several components as an array
5155
function check_dimensions(component::T, weather) where {T<:AbstractArray{<:ModelList}}
5256
for i in component

src/component_models/get_status.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ function status(m)
4444
m.status
4545
end
4646

47+
status(m::ModelMapping{SingleScale}) = status(m.data)
48+
status(m::ModelMapping{SingleScale}, key::Symbol) = status(m.data, key)
49+
status(m::ModelMapping{SingleScale}, key::T) where {T<:Integer} = status(m.data, key)
50+
4751
function status(m::T) where {T<:AbstractArray{M} where {M}}
4852
[status(i) for i in m]
4953
end

src/mtg/mapping/mapping.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ model-specific methods that return a comparable value (for example `Dates.Period
135135
model_rate(::AbstractModel) = nothing
136136
model_rate(model::MultiScaleModel) = model_rate(model_(model))
137137

138-
Base.length(mapping::ModelMapping) = length(mapping.data)
138+
Base.length(mapping::ModelMapping{MultiScale}) = length(mapping.data)
139+
Base.length(::ModelMapping{SingleScale}) = 1
139140
Base.iterate(mapping::ModelMapping{MultiScale}, state...) = iterate(mapping.data, state...)
140141
# Base.iterate(mapping::ModelMapping{SingleScale}, state...) = iterate(mapping.data.models, state...)
141142
Base.show(io::IO, mapping::ModelMapping) = print(io, "ModelMapping with scales: ", join(keys(mapping), ", "))
@@ -154,20 +155,40 @@ end
154155
Base.keys(mapping::ModelMapping) = keys(mapping.data)
155156
Base.values(mapping::ModelMapping) = values(mapping.data)
156157
Base.pairs(mapping::ModelMapping) = pairs(mapping.data)
158+
Base.keys(::ModelMapping{SingleScale}) = ("Default",)
159+
Base.values(mapping::ModelMapping{SingleScale}) = ((values(mapping.data.models)..., status(mapping.data)),)
160+
Base.pairs(mapping::ModelMapping{SingleScale}) = ("Default" => (values(mapping.data.models)..., status(mapping.data)),)
157161
Base.getindex(mapping::ModelMapping, key::String) = mapping.data[key]
158162
Base.getindex(mapping::ModelMapping, key::AbstractString) = mapping.data[String(key)]
163+
function Base.getindex(mapping::ModelMapping{SingleScale}, key::String)
164+
key == "Default" || throw(KeyError(key))
165+
return (values(mapping.data.models)..., status(mapping.data))
166+
end
167+
Base.getindex(mapping::ModelMapping{SingleScale}, key::AbstractString) = getindex(mapping, String(key))
168+
Base.getindex(mapping::ModelMapping{SingleScale}, key::Symbol) = getindex(mapping.data, key)
169+
Base.getindex(mapping::ModelMapping{SingleScale}, key::Integer) = getindex(mapping.data, key)
159170
Base.haskey(mapping::ModelMapping, key::String) = haskey(mapping.data, key)
160171
Base.haskey(mapping::ModelMapping, key::AbstractString) = haskey(mapping.data, String(key))
161172
Base.eltype(::Type{ModelMapping}) = Pair{String,Tuple}
162-
Base.copy(mapping::ModelMapping) = ModelMapping(copy(mapping.data); check=false)
173+
Base.copy(mapping::ModelMapping{MultiScale}) = ModelMapping(copy(mapping.data); check=false)
174+
Base.copy(mapping::ModelMapping{SingleScale}) = ModelMapping{SingleScale,ModelList}(copy(mapping.data))
175+
Base.copy(mapping::ModelMapping{SingleScale}, status) = ModelMapping{SingleScale,ModelList}(copy(mapping.data, status))
163176
Base.Dict(mapping::ModelMapping) = copy(mapping.data)
177+
Base.:(==)(left::ModelMapping{SingleScale}, right::ModelMapping{SingleScale}) = left.data == right.data
178+
179+
function Base.getproperty(mapping::ModelMapping{SingleScale}, name::Symbol)
180+
name === :data && return getfield(mapping, :data)
181+
return getproperty(getfield(mapping, :data), name)
182+
end
164183

165184
function ModelMapping{MultiScale}(mapping::T; check::Bool=true) where {T<:AbstractDict}
166185
normalized = _normalize_multiscale_mapping(mapping)
167186
check && _check_multiscale_mapping!(normalized)
168187
ModelMapping{MultiScale,Dict{String,Tuple}}(normalized)
169188
end
170189

190+
ModelMapping(mapping::AbstractDict; check::Bool=true) = ModelMapping{MultiScale}(mapping; check=check)
191+
171192
ModelMapping(mapping::ModelMapping; check::Bool=true) = check ? ModelMapping(mapping.data; check=true) : mapping
172193

173194
"""
@@ -190,6 +211,12 @@ function ModelMapping(
190211
"No mapping or model was provided. Use `ModelMapping(\"Scale\" => models)` or pass models directly."
191212
)
192213

214+
# Backwards compatibility: allow dict-like construction for type promotion maps,
215+
# e.g. `ModelMapping(Float64 => Float32)`.
216+
if !isempty(args) && all(arg -> arg isa Pair && !(first(arg) isa Union{AbstractString,Symbol}), args)
217+
return Dict(args)
218+
end
219+
193220
if _all_scale_pairs(args)
194221
isempty(processes) || error(
195222
"Cannot mix scale-level pairs with process keyword arguments. ",
@@ -200,7 +227,7 @@ function ModelMapping(
200227
"Provide statuses inside each scale mapping instead."
201228
)
202229
raw_mapping = Dict{String,Any}(String(first(pair)) => last(pair) for pair in args)
203-
return ModelMapping{MultiScale,typeof(raw_mapping)}(raw_mapping; check=check)
230+
return ModelMapping{MultiScale}(raw_mapping; check=check)
204231
end
205232

206233
_contains_scale_like_pair(args) && error(
@@ -215,13 +242,20 @@ function ModelMapping(
215242
end
216243

217244
# Canonical API dispatches for model mappings.
218-
dep(mapping::ModelMapping; verbose::Bool=true) = dep(mapping.data; verbose=verbose)
219-
hard_dependencies(mapping::ModelMapping; verbose::Bool=true) = hard_dependencies(mapping.data; verbose=verbose)
245+
dep(mapping::ModelMapping{SingleScale}; verbose::Bool=true) = dep(mapping.data)
246+
dep(mapping::ModelMapping{MultiScale}; verbose::Bool=true) = dep(mapping.data; verbose=verbose)
247+
hard_dependencies(mapping::ModelMapping{SingleScale}; verbose::Bool=true) = hard_dependencies(mapping.data)
248+
hard_dependencies(mapping::ModelMapping{MultiScale}; verbose::Bool=true) = hard_dependencies(mapping.data; verbose=verbose)
220249
inputs(mapping::ModelMapping) = inputs(mapping.data)
221250
outputs(mapping::ModelMapping) = outputs(mapping.data)
222251
variables(mapping::ModelMapping) = variables(mapping.data)
223252
to_initialize(mapping::ModelMapping, graph=nothing) = to_initialize(mapping.data, graph)
224253
reverse_mapping(mapping::ModelMapping; all=true) = reverse_mapping(mapping.data; all=all)
254+
init_variables(mapping::ModelMapping{SingleScale}; verbose=true) = init_variables(mapping.data; verbose=verbose)
255+
to_initialize(mapping::ModelMapping{SingleScale}) = to_initialize(mapping.data)
256+
to_initialize(mapping::ModelMapping{SingleScale}, graph) = to_initialize(mapping)
257+
pre_allocate_outputs(mapping::ModelMapping{SingleScale}, outs, nsteps; type_promotion=nothing, check=true) =
258+
pre_allocate_outputs(mapping.data, outs, nsteps; type_promotion=type_promotion, check=check)
225259

226260
function _all_scale_pairs(args)
227261
!isempty(args) && all(arg -> arg isa Pair && first(arg) isa Union{AbstractString,Symbol}, args)

src/mtg/mapping/model_generation_from_status_vectors.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ function modellist_to_mapping(modellist_original::ModelList, modellist_status; n
254254
return mtg, ModelMapping(mapping), Dict(default_scale => all_vars)
255255
end
256256

257+
function modellist_to_mapping(mapping::ModelMapping{SingleScale}, modellist_status; nsteps=nothing, outputs=nothing)
258+
modellist_to_mapping(mapping.data, modellist_status; nsteps=nsteps, outputs=outputs)
259+
end
260+
257261
function check_statuses_contain_no_remaining_vectors(mapping)
258262
for (organ,models) in mapping
259263

test/helper-functions.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@ function compare_outputs_modellists(filtered_outputs_1, filtered_outputs_2)
2727
return models_df_sorted_2 == models_df_sorted_1
2828
end
2929

30+
function compare_outputs_modellist_mapping(filtered_outputs_modellist, graphsim)
31+
modellist_df = DataFrame(filtered_outputs_modellist)
32+
modellist_sorted = modellist_df[:, sortperm(names(modellist_df))]
33+
34+
outputs_df = convert_outputs(graphsim.outputs, DataFrame)
35+
@assert haskey(outputs_df, "Default")
36+
common_cols = filter(c -> c in names(outputs_df["Default"]), names(modellist_sorted))
37+
mapping_sorted = outputs_df["Default"][:, sortperm(common_cols)]
38+
modellist_sorted = modellist_sorted[:, sortperm(common_cols)]
39+
40+
return modellist_sorted == mapping_sorted
41+
end
42+
3043
# Breaking this function into two to ensure eval() state synchronisation happens (see comments around the modellist_to_mapping definition)
3144
# Naming could be better
3245
function check_multiscale_simulation_is_equivalent_begin(mapping::ModelMapping, meteo)

test/test-toy_models.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,18 @@ end
6161

6262
# Uninitialized:
6363
to_init_uninitialized = to_initialize(ModelMapping(ToyAssimGrowthModel()))
64-
@test haskey(to_init_uninitialized, "Default")
65-
@test :aPPFD in to_init_uninitialized["Default"]
64+
if to_init_uninitialized isa AbstractDict
65+
@test haskey(to_init_uninitialized, "Default")
66+
@test :aPPFD in to_init_uninitialized["Default"]
67+
else
68+
@test :growth in keys(to_init_uninitialized)
69+
@test :aPPFD in to_init_uninitialized[:growth]
70+
end
6671

6772
# One time step:
6873
mapping = ModelMapping(ToyAssimGrowthModel(); status=(aPPFD=30.0,))
6974

70-
@test to_initialize(mapping) == Dict()
75+
@test isempty(to_initialize(mapping))
7176

7277
outputs = run!(mapping)
7378
@test outputs[:biomass] [4.5]

0 commit comments

Comments
 (0)