@@ -718,6 +718,217 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718 return pca. p_constructor (pca .(x))
719719end
720720
721+ """
722+ $TYPEDEF
723+
724+ Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+ act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+ for the supported templates.
727+ """
728+ struct CopyParamsByTemplate{T, N}
729+ """
730+ Whether this is a "root" object, in that it directly contains instantiable templates instead
731+ of recursing into other `CopyParamsByTemplate`.
732+ """
733+ isroot:: Bool
734+ """
735+ List of templates.
736+ """
737+ template:: T # TODO : This field is parametric because I thought we might want to specialize it in some cases.
738+ """
739+ Size of the returned buffer.
740+ """
741+ size:: NTuple{N, Int}
742+ end
743+
744+ function __apply_copy_template (valp, @nospecialize (template))
745+ p = parameter_values (valp)
746+ u = state_values (valp)
747+ if template isa ParameterIndex{SciMLStructures. Tunable, UnitRange{Int}}
748+ return p. tunable[template. idx]
749+ elseif template isa ParameterIndex{SciMLStructures. Initials, UnitRange{Int}}
750+ return p. initials[template. idx]
751+ elseif template isa ParameterIndex{SciMLStructures. Discrete, Tuple{Int, UnitRange{Int}}}
752+ return p. discrete[template. idx[1 ]][template. idx[2 ]]
753+ elseif template isa ParameterIndex{SciMLStructures. Constants, Tuple{Int, UnitRange{Int}}}
754+ return p. constant[template. idx[1 ]][template. idx[2 ]]
755+ elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
756+ return p. nonnumeric[template. idx[1 ]][template. idx[2 ]]
757+ elseif template isa UnitRange{Int}
758+ return u[template]
759+ elseif template isa ObservedWrapper
760+ return template (valp)
761+ elseif template isa CopyParamsByTemplate
762+ return template (valp)
763+ else
764+ # MethodError because this is a manual dispatch chain
765+ throw (MethodError (__apply_copy_template, (valp, template)))
766+ end
767+ end
768+
769+ function (cp:: CopyParamsByTemplate )(src)
770+ if cp. isroot
771+ reshape (mapreduce (Base. Fix1 (__apply_copy_template, src), vcat, cp. template), cp. size)
772+ else
773+ reshape (map (Base. Fix1 (__apply_copy_template, src), cp. template), cp. size)
774+ end
775+ end
776+
777+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray{SymbolicT} ; kws... )
778+ template = []
779+ for sym in syms
780+ symidx = parameter_index (srcsys, sym)
781+ if symidx === nothing
782+ symidx = variable_index (srcsys, sym)
783+ if symidx === nothing
784+ if isempty (template)
785+ push! (template, SymbolicT[sym])
786+ continue
787+ end
788+ prev = template[end ]
789+ if prev isa Vector{SymbolicT}
790+ push! (prev, sym)
791+ else
792+ push! (template, SymbolicT[sym])
793+ end
794+ continue
795+ end
796+ if isempty (template)
797+ push! (template, symidx: symidx)
798+ continue
799+ end
800+ prev = template[end ]
801+ if prev isa UnitRange{Int} && last (prev) + 1 == symidx
802+ template[end ] = first (prev): symidx
803+ else
804+ push! (template, symidx: symidx)
805+ end
806+ continue
807+ end
808+ portion = symidx. portion
809+ _bufidx = symidx. idx
810+ if portion isa SciMLStructures. Tunable || portion isa SciMLStructures. Initials
811+
812+ end
813+ bufidx:: UnitRange{Int} = if _bufidx isa AbstractVector{Int}
814+ @assert isequal (vec (_bufidx), first (_bufidx): last (_bufidx))
815+ subidx = nothing
816+ first (_bufidx): last (_bufidx)
817+ elseif _bufidx isa Int
818+ subidx = nothing
819+ _bufidx: _bufidx
820+ elseif _bufidx isa NTuple{2 , Int}
821+ subidx = _bufidx[1 ]
822+ _bufidx[2 ]: _bufidx[2 ]
823+ else
824+ # Will error due to the typeassert on `bufidx`
825+ nothing
826+ end
827+ if isempty (template)
828+ if subidx === nothing
829+ push! (template, ParameterIndex (symidx. portion, bufidx))
830+ else
831+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
832+ end
833+ continue
834+ end
835+ prev = template[end ]
836+ if prev isa ParameterIndex && prev. portion === symidx. portion && (subidx === nothing && last (prev. idx) + 1 == first (bufidx) || subidx == prev. idx[1 ] && last (prev. idx[2 ]) + 1 == first (bufidx))
837+ if subidx === nothing
838+ template[end ] = ParameterIndex (prev. portion, first (prev. idx): last (bufidx))
839+ else
840+ template[end ] = ParameterIndex (prev. portion, (subidx, first (prev. idx[2 ]): last (bufidx)))
841+ end
842+ elseif subidx === nothing
843+ push! (template, ParameterIndex (symidx. portion, bufidx))
844+ else
845+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
846+ end
847+ end
848+
849+ for i in eachindex (template)
850+ if template[i] isa Vector{SymbolicT}
851+ template[i] = concrete_getu (srcsys, template[i]; kws... )
852+ end
853+ end
854+ return CopyParamsByTemplate (true , template, size (syms))
855+ end
856+
857+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray ; kws... )
858+ return CopyParamsByTemplate (false , [CopyParamsByTemplate (srcsys, sym; kws... ) for sym in syms], size (syms))
859+ end
860+
861+ struct MTKParametersReconstructor{T, I, D, C, N}
862+ tunables_fn:: T
863+ initials_fn:: I
864+ discretes_fn:: D
865+ consts_fn:: C
866+ nonnumerics_fn:: N
867+ diffcache_buffer_idx:: Int
868+ end
869+
870+ function (recon:: MTKParametersReconstructor )(src, dst)
871+ if state_values (src) === nothing && ! (current_time (src) isa Real && isfinite (current_time (src)))
872+ baseT = Nothing
873+ elseif state_values (src) === nothing
874+ baseT = typeof (current_time (src))
875+ elseif ! (current_time (src) isa Real && isfinite (current_time (src)))
876+ baseT = eltype (state_values (src))
877+ else
878+ baseT = promote_type (typeof (current_time (src)), eltype (state_values (src)))
879+ end
880+ src_ps = parameter_values (src)
881+ dst_ps = parameter_values (dst)
882+ oldcache = dst_ps. caches
883+ # These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
884+ # TODO : Find a solution to type-assert constants/discretes/nonnumerics
885+ # TODO : Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
886+ if dst_ps. tunable isa SVector{0 }
887+ tunablevals = recon. tunables_fn (src)
888+ else
889+ tunable_elT = promote_type (eltype (dst_ps. tunable), eltype (src_ps. tunable))
890+ if baseT != = Nothing
891+ tunable_elT = promote_type (tunable_elT, baseT)
892+ end
893+ baseT = tunable_elT
894+ if ArrayInterface. ismutable (dst_ps. tunable)
895+ tunable_T = Base. promote_op (similar, typeof (dst_ps. tunable), Type{tunable_elT})
896+ tunablevals = recon. tunables_fn (src):: tunable_T
897+ else
898+ tunable_T = StaticArraysCore. similar_type (typeof (dst_ps. tunable), tunable_elT)
899+ tunablevals = recon. tunables_fn (src):: tunable_T
900+ end
901+ end
902+ if dst_ps. initials isa SVector{0 }
903+ initialvals = recon. initials_fn (src)
904+ else
905+ initial_elT = promote_type (eltype (dst_ps. initials), eltype (src_ps. initials))
906+ if baseT != = Nothing
907+ initial_elT = promote_type (initial_elT, baseT)
908+ end
909+ if ArrayInterface. ismutable (dst_ps. initials)
910+ initial_T = Base. promote_op (similar, typeof (dst_ps. initials), Type{initial_elT})
911+ initialvals = recon. initials_fn (src):: initial_T
912+ else
913+ initial_T = StaticArraysCore. similar_type (typeof (dst_ps. initials), initial_elT)
914+ initialvals = recon. initials_fn (src):: initial_T
915+ end
916+ end
917+
918+ nonnumerics = recon. nonnumerics_fn (src):: typeof (dst_ps. nonnumeric)
919+ (; diffcache_buffer_idx) = recon
920+ if ! iszero (diffcache_buffer_idx)
921+ @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
922+ end
923+ return promote_with_nothing (
924+ promote_type_with_nothing (eltype (tunablevals), initialvals),
925+ MTKParameters (
926+ tunablevals, initialvals, recon. discretes_fn (src),
927+ recon. consts_fn (src), nonnumerics, oldcache isa Tuple{} ? () : copy .(oldcache)
928+ )
929+ )
930+ end
931+
721932"""
722933 $(TYPEDSIGNATURES)
723934
@@ -732,10 +943,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732943 unwrapped.
733944- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734945"""
735- function get_mtkparameters_reconstructor (
946+ function MTKParametersReconstructor (
736947 srcsys:: AbstractSystem , dstsys:: AbstractSystem ;
737948 initials = false , unwrap_initials = false , p_constructor = identity,
738- eval_expression = false , eval_module = @__MODULE__ , force_time_independent = false ,
949+ force_time_independent = false ,
739950 kwargs...
740951 )
741952 _p_constructor = p_constructor
@@ -754,32 +965,27 @@ function get_mtkparameters_reconstructor(
754965 tunable_getter = if isempty (tunable_syms)
755966 Returns (SVector {0, Float64} ())
756967 else
757- p_constructor ∘ concrete_getu (
758- srcsys, tunable_syms; eval_expression, eval_module,
759- force_time_independent, kwargs...
760- )
968+ p_constructor ∘ CopyParamsByTemplate (srcsys, tunable_syms; kwargs... )
761969 end
762970 initials_getter = if initials && ! isempty (syms[2 ])
763- initsyms = Vector {Any} ( syms[2 ])
764- allsyms = Set (variable_symbols (srcsys))
971+ initsyms = syms[2 ]:: Vector{SymbolicT}
972+ allsyms = Set {SymbolicT} (variable_symbols (srcsys))
765973 if unwrap_initials
766974 for i in eachindex (initsyms)
767975 sym = initsyms[i]
768- innersym = if operation (sym) === getindex
769- sym, idxs... = arguments (sym)
770- only (arguments (sym))[idxs... ]
976+ arr, isarr = split_indexed_var (sym)
977+ innersym = if isarr
978+ sidx = get_stable_index (sym)
979+ first (arguments (arr))[sidx]
771980 else
772- only (arguments (sym ))
981+ first (arguments (arr ))
773982 end
774983 if innersym in allsyms
775984 initsyms[i] = innersym
776985 end
777986 end
778987 end
779- p_constructor ∘ concrete_getu (
780- srcsys, initsyms; eval_expression, eval_module,
781- force_time_independent, kwargs...
782- )
988+ p_constructor ∘ CopyParamsByTemplate (srcsys, initsyms; kwargs... )
783989 else
784990 Returns (SVector {0, Float64} ())
785991 end
@@ -792,29 +998,25 @@ function get_mtkparameters_reconstructor(
792998 p_constructor (map (x -> x. length, bufsizes))
793999 end
7941000 )
1001+
7951002 # discretes need to be blocked arrays
7961003 # the `getu` returns a tuple of arrays corresponding to `p.discretes`
7971004 # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
7981005 # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
7991006 # tuple of `BlockedArray`s
8001007 Base. Fix2 (Broadcast. BroadcastFunction (BlockedArray), blockarrsizes) ∘
801- Base. Fix1 (broadcast, p_constructor) ∘
1008+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘
8021009 # This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
8031010 # appearing in the result.
804- concrete_getu (
805- srcsys, Tuple (broadcast .(collect, syms[3 ]));
806- eval_expression, eval_module, force_time_independent, kwargs...
807- )
1011+ CopyParamsByTemplate (srcsys, broadcast .(collect, syms[3 ]); kwargs... )
8081012 end
809- const_getter = if syms[4 ] == ( )
1013+ const_getter = if isempty ( syms[4 ])
8101014 Returns (())
8111015 else
812- Base. Fix1 (broadcast, p_constructor) ∘ concrete_getu (
813- srcsys, Tuple (syms[4 ]);
814- eval_expression, eval_module, force_time_independent, kwargs...
815- )
1016+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[4 ]; kwargs... )
8161017 end
817- nonnumeric_getter = if syms[5 ] == ()
1018+ diffcache_buffer_idx = 0
1019+ nonnumeric_getter = if isempty (syms[5 ])
8181020 Returns (())
8191021 else
8201022 ic = get_index_cache (dstsys)
@@ -825,44 +1027,19 @@ function get_mtkparameters_reconstructor(
8251027 )
8261028
8271029 diffcache_params = SU. getmetadata (dstsys, DiffCacheParams, Dict {SymbolicT, Int} ()):: Dict{SymbolicT, Int}
828- diffcache_buffer_idx = 0
8291030 if ! isempty (diffcache_params)
8301031 representative = first (keys (diffcache_params))
8311032 diffcache_buffer_idx, _ = ic. nonnumeric_idx[representative]
8321033 @set! buftypes[diffcache_buffer_idx] = identity
1034+ for (i, sym) in enumerate (syms[5 ][diffcache_buffer_idx])
1035+ end
8331036 end
8341037 # nonnumerics retain the assigned buffer type without narrowing
8351038 Base. Fix1 (broadcast, _p_constructor) ∘
836- Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘
837- concrete_getu (
838- srcsys, Tuple (syms[5 ]);
839- eval_expression, eval_module, force_time_independent, kwargs...
840- )
841- end
842- getters = (
843- tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844- )
845- getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846- function _getter (valp, initprob)
847- oldcache = parameter_values (initprob). caches
848- tunablevals = getters[1 ](valp)
849- initialvals = getters[2 ](valp)
850- nonnumerics = getters[5 ](valp)
851- if ! iszero (diffcache_buffer_idx)
852- @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
853- end
854- return promote_with_nothing (
855- promote_type_with_nothing (eltype (tunablevals), initialvals),
856- MTKParameters (
857- tunablevals, initialvals, getters[3 ](valp),
858- getters[4 ](valp), nonnumerics, oldcache isa Tuple{} ? () :
859- copy .(oldcache)
860- )
861- )
862- end
1039+ Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[5 ]; kwargs... )
8631040 end
8641041
865- return getter
1042+ return MTKParametersReconstructor (tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661043end
8671044
8681045function call (f, args... )
@@ -890,7 +1067,7 @@ function ReconstructInitializeprob(
8901067 kwargs...
8911068 )
8921069 if is_split (dstsys)
893- pgetter = get_mtkparameters_reconstructor (
1070+ pgetter = MTKParametersReconstructor (
8941071 srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951072 force_time_independent = is_steadystateprob, kwargs...
8961073 )
@@ -970,7 +1147,7 @@ function construct_initializeprobpmap(
9701147 )
9711148 @assert is_initializesystem (initsys)
9721149 if is_split (sys)
973- return let getter = get_mtkparameters_reconstructor (
1150+ return let getter = MTKParametersReconstructor (
9741151 initsys, sys; initials = true , unwrap_initials = true , p_constructor,
9751152 eval_expression, eval_module, kwargs...
9761153 )
0 commit comments