@@ -718,6 +718,226 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718 return pca. p_constructor (pca .(x))
719719end
720720
721+ """
722+ $TYPEDEF
723+
724+ Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+ act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+ for the supported templates.
727+ """
728+ struct CopyParamsByTemplate{T, N}
729+ """
730+ Whether this is a "root" object, in that it directly contains instantiable templates instead
731+ of recursing into other `CopyParamsByTemplate`.
732+ """
733+ isroot:: Bool
734+ """
735+ List of templates.
736+ """
737+ template:: T # TODO : This field is parametric because I thought we might want to specialize it in some cases.
738+ """
739+ Size of the returned buffer.
740+ """
741+ size:: NTuple{N, Int}
742+ end
743+
744+ function __apply_copy_template (valp, @nospecialize (template))
745+ p = parameter_values (valp)
746+ u = state_values (valp)
747+ if template isa ParameterIndex{SciMLStructures. Tunable, UnitRange{Int}}
748+ if p isa MTKParameters
749+ return p. tunable[template. idx]
750+ else
751+ return p[template. idx]
752+ end
753+ elseif template isa ParameterIndex{SciMLStructures. Initials, UnitRange{Int}}
754+ return p. initials[template. idx]
755+ elseif template isa ParameterIndex{SciMLStructures. Discrete, Tuple{Int, UnitRange{Int}}}
756+ return p. discrete[template. idx[1 ]][template. idx[2 ]]
757+ elseif template isa ParameterIndex{SciMLStructures. Constants, Tuple{Int, UnitRange{Int}}}
758+ return p. constant[template. idx[1 ]][template. idx[2 ]]
759+ elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
760+ return p. nonnumeric[template. idx[1 ]][template. idx[2 ]]
761+ elseif template isa UnitRange{Int}
762+ return u[template]
763+ elseif template isa ObservedWrapper
764+ return template (valp)
765+ elseif template isa CopyParamsByTemplate
766+ return template (valp)
767+ else
768+ # MethodError because this is a manual dispatch chain
769+ throw (MethodError (__apply_copy_template, (valp, template)))
770+ end
771+ end
772+
773+ function (cp:: CopyParamsByTemplate )(src)
774+ if cp. isroot
775+ reshape (mapreduce (Base. Fix1 (__apply_copy_template, src), vcat, cp. template), cp. size)
776+ else
777+ reshape (map (Base. Fix1 (__apply_copy_template, src), cp. template), cp. size)
778+ end
779+ end
780+
781+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray{SymbolicT} ; kws... )
782+ template = []
783+ for sym in syms
784+ symidx = parameter_index (srcsys, sym)
785+ if symidx === nothing
786+ symidx = variable_index (srcsys, sym)
787+ if symidx === nothing
788+ if isempty (template)
789+ push! (template, SymbolicT[sym])
790+ continue
791+ end
792+ prev = template[end ]
793+ if prev isa Vector{SymbolicT}
794+ push! (prev, sym)
795+ else
796+ push! (template, SymbolicT[sym])
797+ end
798+ continue
799+ end
800+ if isempty (template)
801+ push! (template, symidx: symidx)
802+ continue
803+ end
804+ prev = template[end ]
805+ if prev isa UnitRange{Int} && last (prev) + 1 == symidx
806+ template[end ] = first (prev): symidx
807+ else
808+ push! (template, symidx: symidx)
809+ end
810+ continue
811+ elseif symidx isa Int
812+ symidx = ParameterIndex (SciMLStructures. Tunable (), symidx)
813+ end
814+ portion = symidx. portion
815+ _bufidx = symidx. idx
816+ bufidx:: UnitRange{Int} = if _bufidx isa AbstractVector{Int}
817+ @assert isequal (vec (_bufidx), first (_bufidx): last (_bufidx))
818+ subidx = nothing
819+ first (_bufidx): last (_bufidx)
820+ elseif _bufidx isa Int
821+ subidx = nothing
822+ _bufidx: _bufidx
823+ elseif _bufidx isa NTuple{2 , Int}
824+ subidx = _bufidx[1 ]
825+ _bufidx[2 ]: _bufidx[2 ]
826+ else
827+ # Will error due to the typeassert on `bufidx`
828+ nothing
829+ end
830+ if isempty (template)
831+ if subidx === nothing
832+ push! (template, ParameterIndex (symidx. portion, bufidx))
833+ else
834+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
835+ end
836+ continue
837+ end
838+ prev = template[end ]
839+ if prev isa ParameterIndex && prev. portion === symidx. portion && (subidx === nothing && last (prev. idx) + 1 == first (bufidx) || subidx == prev. idx[1 ] && last (prev. idx[2 ]) + 1 == first (bufidx))
840+ if subidx === nothing
841+ template[end ] = ParameterIndex (prev. portion, first (prev. idx): last (bufidx))
842+ else
843+ template[end ] = ParameterIndex (prev. portion, (subidx, first (prev. idx[2 ]): last (bufidx)))
844+ end
845+ elseif subidx === nothing
846+ push! (template, ParameterIndex (symidx. portion, bufidx))
847+ else
848+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
849+ end
850+ end
851+
852+ for i in eachindex (template)
853+ if template[i] isa Vector{SymbolicT}
854+ template[i] = concrete_getu (srcsys, template[i]; kws... )
855+ end
856+ end
857+ return CopyParamsByTemplate (true , template, size (syms))
858+ end
859+
860+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray ; kws... )
861+ return CopyParamsByTemplate (false , [CopyParamsByTemplate (srcsys, sym; kws... ) for sym in syms], size (syms))
862+ end
863+
864+ struct MTKParametersReconstructor{T, I, D, C, N}
865+ tunables_fn:: T
866+ initials_fn:: I
867+ discretes_fn:: D
868+ consts_fn:: C
869+ nonnumerics_fn:: N
870+ diffcache_buffer_idx:: Int
871+ end
872+
873+ function (recon:: MTKParametersReconstructor )(src, dst)
874+ if state_values (src) === nothing && ! (applicable (current_time, src) && current_time (src) isa Real && isfinite (current_time (src)))
875+ baseT = Nothing
876+ elseif state_values (src) === nothing
877+ baseT = typeof (current_time (src))
878+ elseif ! (applicable (current_time, src) && current_time (src) isa Real && isfinite (current_time (src)))
879+ baseT = eltype (state_values (src))
880+ else
881+ baseT = promote_type (typeof (current_time (src)), eltype (state_values (src)))
882+ end
883+ src_ps = parameter_values (src)
884+ dst_ps = parameter_values (dst)
885+ oldcache = dst_ps. caches
886+ # These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
887+ # TODO : Find a solution to type-assert constants/discretes/nonnumerics
888+ # TODO : Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
889+ if dst_ps. tunable isa SVector{0 }
890+ tunablevals = recon. tunables_fn (src)
891+ else
892+ tunable_elT = eltype (dst_ps. tunable)
893+ if ! (src_ps. tunable isa SVector{0 })
894+ tunable_elT = promote_type (tunable_elT, eltype (src_ps. tunable))
895+ end
896+ if baseT != = Nothing
897+ tunable_elT = promote_type (tunable_elT, baseT)
898+ end
899+ baseT = tunable_elT
900+ if ArrayInterface. ismutable (dst_ps. tunable)
901+ tunable_T = Base. promote_op (similar, typeof (dst_ps. tunable), Type{tunable_elT})
902+ tunablevals = recon. tunables_fn (src):: tunable_T
903+ else
904+ tunable_T = StaticArraysCore. similar_type (typeof (dst_ps. tunable), tunable_elT)
905+ tunablevals = recon. tunables_fn (src):: tunable_T
906+ end
907+ end
908+ if dst_ps. initials isa SVector{0 }
909+ initialvals = recon. initials_fn (src)
910+ else
911+ initial_elT = eltype (dst_ps. initials)
912+ if ! (src_ps. initials isa SVector{0 })
913+ initial_elT = promote_type (initial_elT, eltype (src_ps. initials))
914+ end
915+ if baseT != = Nothing
916+ initial_elT = promote_type (initial_elT, baseT)
917+ end
918+ if ArrayInterface. ismutable (dst_ps. initials)
919+ initial_T = Base. promote_op (similar, typeof (dst_ps. initials), Type{initial_elT})
920+ initialvals = recon. initials_fn (src):: initial_T
921+ else
922+ initial_T = StaticArraysCore. similar_type (typeof (dst_ps. initials), initial_elT)
923+ initialvals = recon. initials_fn (src):: initial_T
924+ end
925+ end
926+
927+ nonnumerics = recon. nonnumerics_fn (src):: typeof (dst_ps. nonnumeric)
928+ (; diffcache_buffer_idx) = recon
929+ if ! iszero (diffcache_buffer_idx)
930+ @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
931+ end
932+ return promote_with_nothing (
933+ promote_type_with_nothing (eltype (tunablevals), initialvals),
934+ MTKParameters (
935+ tunablevals, initialvals, recon. discretes_fn (src),
936+ recon. consts_fn (src), nonnumerics, oldcache isa Tuple{} ? () : copy .(oldcache)
937+ )
938+ )
939+ end
940+
721941"""
722942 $(TYPEDSIGNATURES)
723943
@@ -732,10 +952,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732952 unwrapped.
733953- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734954"""
735- function get_mtkparameters_reconstructor (
955+ function MTKParametersReconstructor (
736956 srcsys:: AbstractSystem , dstsys:: AbstractSystem ;
737957 initials = false , unwrap_initials = false , p_constructor = identity,
738- eval_expression = false , eval_module = @__MODULE__ , force_time_independent = false ,
958+ force_time_independent = false ,
739959 kwargs...
740960 )
741961 _p_constructor = p_constructor
@@ -754,32 +974,27 @@ function get_mtkparameters_reconstructor(
754974 tunable_getter = if isempty (tunable_syms)
755975 Returns (SVector {0, Float64} ())
756976 else
757- p_constructor ∘ concrete_getu (
758- srcsys, tunable_syms; eval_expression, eval_module,
759- force_time_independent, kwargs...
760- )
977+ p_constructor ∘ CopyParamsByTemplate (srcsys, tunable_syms; kwargs... )
761978 end
762979 initials_getter = if initials && ! isempty (syms[2 ])
763- initsyms = Vector {Any} ( syms[2 ])
764- allsyms = Set (variable_symbols (srcsys))
980+ initsyms = syms[2 ]:: Vector{SymbolicT}
981+ allsyms = Set {SymbolicT} (variable_symbols (srcsys))
765982 if unwrap_initials
766983 for i in eachindex (initsyms)
767984 sym = initsyms[i]
768- innersym = if operation (sym) === getindex
769- sym, idxs... = arguments (sym)
770- only (arguments (sym))[idxs... ]
985+ arr, isarr = split_indexed_var (sym)
986+ innersym = if isarr
987+ sidx = get_stable_index (sym)
988+ first (arguments (arr))[sidx]
771989 else
772- only (arguments (sym ))
990+ first (arguments (arr ))
773991 end
774992 if innersym in allsyms
775993 initsyms[i] = innersym
776994 end
777995 end
778996 end
779- p_constructor ∘ concrete_getu (
780- srcsys, initsyms; eval_expression, eval_module,
781- force_time_independent, kwargs...
782- )
997+ p_constructor ∘ CopyParamsByTemplate (srcsys, initsyms; kwargs... )
783998 else
784999 Returns (SVector {0, Float64} ())
7851000 end
@@ -792,29 +1007,25 @@ function get_mtkparameters_reconstructor(
7921007 p_constructor (map (x -> x. length, bufsizes))
7931008 end
7941009 )
1010+
7951011 # discretes need to be blocked arrays
7961012 # the `getu` returns a tuple of arrays corresponding to `p.discretes`
7971013 # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
7981014 # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
7991015 # tuple of `BlockedArray`s
8001016 Base. Fix2 (Broadcast. BroadcastFunction (BlockedArray), blockarrsizes) ∘
801- Base. Fix1 (broadcast, p_constructor) ∘
1017+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘
8021018 # This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
8031019 # appearing in the result.
804- concrete_getu (
805- srcsys, Tuple (broadcast .(collect, syms[3 ]));
806- eval_expression, eval_module, force_time_independent, kwargs...
807- )
1020+ CopyParamsByTemplate (srcsys, broadcast .(collect, syms[3 ]); kwargs... )
8081021 end
809- const_getter = if syms[4 ] == ( )
1022+ const_getter = if isempty ( syms[4 ])
8101023 Returns (())
8111024 else
812- Base. Fix1 (broadcast, p_constructor) ∘ concrete_getu (
813- srcsys, Tuple (syms[4 ]);
814- eval_expression, eval_module, force_time_independent, kwargs...
815- )
1025+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[4 ]; kwargs... )
8161026 end
817- nonnumeric_getter = if syms[5 ] == ()
1027+ diffcache_buffer_idx = 0
1028+ nonnumeric_getter = if isempty (syms[5 ])
8181029 Returns (())
8191030 else
8201031 ic = get_index_cache (dstsys)
@@ -825,44 +1036,19 @@ function get_mtkparameters_reconstructor(
8251036 )
8261037
8271038 diffcache_params = SU. getmetadata (dstsys, DiffCacheParams, Dict {SymbolicT, Int} ()):: Dict{SymbolicT, Int}
828- diffcache_buffer_idx = 0
8291039 if ! isempty (diffcache_params)
8301040 representative = first (keys (diffcache_params))
8311041 diffcache_buffer_idx, _ = ic. nonnumeric_idx[representative]
8321042 @set! buftypes[diffcache_buffer_idx] = identity
1043+ for (i, sym) in enumerate (syms[5 ][diffcache_buffer_idx])
1044+ end
8331045 end
8341046 # nonnumerics retain the assigned buffer type without narrowing
8351047 Base. Fix1 (broadcast, _p_constructor) ∘
836- Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘
837- concrete_getu (
838- srcsys, Tuple (syms[5 ]);
839- eval_expression, eval_module, force_time_independent, kwargs...
840- )
841- end
842- getters = (
843- tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844- )
845- getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846- function _getter (valp, initprob)
847- oldcache = parameter_values (initprob). caches
848- tunablevals = getters[1 ](valp)
849- initialvals = getters[2 ](valp)
850- nonnumerics = getters[5 ](valp)
851- if ! iszero (diffcache_buffer_idx)
852- @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
853- end
854- return promote_with_nothing (
855- promote_type_with_nothing (eltype (tunablevals), initialvals),
856- MTKParameters (
857- tunablevals, initialvals, getters[3 ](valp),
858- getters[4 ](valp), nonnumerics, oldcache isa Tuple{} ? () :
859- copy .(oldcache)
860- )
861- )
862- end
1048+ Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[5 ]; kwargs... )
8631049 end
8641050
865- return getter
1051+ return MTKParametersReconstructor (tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661052end
8671053
8681054function call (f, args... )
@@ -890,7 +1076,7 @@ function ReconstructInitializeprob(
8901076 kwargs...
8911077 )
8921078 if is_split (dstsys)
893- pgetter = get_mtkparameters_reconstructor (
1079+ pgetter = MTKParametersReconstructor (
8941080 srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951081 force_time_independent = is_steadystateprob, kwargs...
8961082 )
@@ -970,7 +1156,7 @@ function construct_initializeprobpmap(
9701156 )
9711157 @assert is_initializesystem (initsys)
9721158 if is_split (sys)
973- return let getter = get_mtkparameters_reconstructor (
1159+ return let getter = MTKParametersReconstructor (
9741160 initsys, sys; initials = true , unwrap_initials = true , p_constructor,
9751161 eval_expression, eval_module, kwargs...
9761162 )
0 commit comments