@@ -718,6 +718,201 @@ function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
718718 return pca. p_constructor (pca .(x))
719719end
720720
721+ """
722+ $TYPEDEF
723+
724+ Callable struct designed for use by `MTKParametersReconstructor`. Uses a fixed set of templates to
725+ act as a very dynamic (and limited) observed function returning an array. See `__apply_copy_template`
726+ for the supported templates.
727+ """
728+ struct CopyParamsByTemplate{T, N}
729+ """
730+ Whether this is a "root" object, in that it directly contains instantiable templates instead
731+ of recursing into other `CopyParamsByTemplate`.
732+ """
733+ isroot:: Bool
734+ """
735+ List of templates.
736+ """
737+ template:: T # TODO : This field is parametric because I thought we might want to specialize it in some cases.
738+ """
739+ Size of the returned buffer.
740+ """
741+ size:: NTuple{N, Int}
742+ end
743+
744+ function __apply_copy_template (valp, @nospecialize (template))
745+ p = parameter_values (valp)
746+ u = state_values (valp)
747+ if template isa ParameterIndex{SciMLStructures. Tunable, UnitRange{Int}}
748+ return p. tunable[template. idx]
749+ elseif template isa ParameterIndex{SciMLStructures. Initials, UnitRange{Int}}
750+ return p. initials[template. idx]
751+ elseif template isa ParameterIndex{SciMLStructures. Discrete, Tuple{Int, UnitRange{Int}}}
752+ return p. discrete[template. idx[1 ]][template. idx[2 ]]
753+ elseif template isa ParameterIndex{SciMLStructures. Constants, Tuple{Int, UnitRange{Int}}}
754+ return p. constant[template. idx[1 ]][template. idx[2 ]]
755+ elseif template isa ParameterIndex{Nonnumeric, Tuple{Int, UnitRange{Int}}}
756+ return p. nonnumeric[template. idx[1 ]][template. idx[2 ]]
757+ elseif template isa UnitRange{Int}
758+ return u[template]
759+ elseif template isa ObservedWrapper
760+ return template (valp)
761+ elseif template isa CopyParamsByTemplate
762+ return template (valp)
763+ else
764+ # MethodError because this is a manual dispatch chain
765+ throw (MethodError (__apply_copy_template, (valp, template)))
766+ end
767+ end
768+
769+ function (cp:: CopyParamsByTemplate )(src)
770+ if cp. isroot
771+ reshape (mapreduce (Base. Fix1 (__apply_copy_template, src), vcat, cp. template), cp. size)
772+ else
773+ reshape (map (Base. Fix1 (__apply_copy_template, src), cp. template), cp. size)
774+ end
775+ end
776+
777+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray{SymbolicT} ; kws... )
778+ template = []
779+ for sym in syms
780+ symidx = parameter_index (srcsys, sym)
781+ if symidx === nothing
782+ symidx = variable_index (srcsys, sym)
783+ if symidx === nothing
784+ if isempty (template)
785+ push! (template, SymbolicT[sym])
786+ continue
787+ end
788+ prev = template[end ]
789+ if prev isa Vector{SymbolicT}
790+ push! (prev, sym)
791+ else
792+ push! (template, SymbolicT[sym])
793+ end
794+ continue
795+ end
796+ if isempty (template)
797+ push! (template, symidx: symidx)
798+ continue
799+ end
800+ prev = template[end ]
801+ if prev isa UnitRange{Int} && last (prev) + 1 == symidx
802+ template[end ] = first (prev): symidx
803+ else
804+ push! (template, symidx: symidx)
805+ end
806+ continue
807+ end
808+ portion = symidx. portion
809+ _bufidx = symidx. idx
810+ if portion isa SciMLStructures. Tunable || portion isa SciMLStructures. Initials
811+
812+ end
813+ bufidx:: UnitRange{Int} = if _bufidx isa AbstractVector{Int}
814+ @assert isequal (vec (_bufidx), first (_bufidx): last (_bufidx))
815+ subidx = nothing
816+ first (_bufidx): last (_bufidx)
817+ elseif _bufidx isa Int
818+ subidx = nothing
819+ _bufidx: _bufidx
820+ elseif _bufidx isa NTuple{2 , Int}
821+ subidx = _bufidx[1 ]
822+ _bufidx[2 ]: _bufidx[2 ]
823+ else
824+ # Will error due to the typeassert on `bufidx`
825+ nothing
826+ end
827+ if isempty (template)
828+ if subidx === nothing
829+ push! (template, ParameterIndex (symidx. portion, bufidx))
830+ else
831+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
832+ end
833+ continue
834+ end
835+ prev = template[end ]
836+ if prev isa ParameterIndex && prev. portion === symidx. portion && (subidx === nothing && last (prev. idx) + 1 == first (bufidx) || subidx == prev. idx[1 ] && last (prev. idx[2 ]) + 1 == first (bufidx))
837+ if subidx === nothing
838+ template[end ] = ParameterIndex (prev. portion, first (prev. idx): last (bufidx))
839+ else
840+ template[end ] = ParameterIndex (prev. portion, (subidx, first (prev. idx[2 ]): last (bufidx)))
841+ end
842+ elseif subidx === nothing
843+ push! (template, ParameterIndex (symidx. portion, bufidx))
844+ else
845+ push! (template, ParameterIndex (symidx. portion, (subidx, bufidx)))
846+ end
847+ end
848+
849+ for i in eachindex (template)
850+ if template[i] isa Vector{SymbolicT}
851+ template[i] = concrete_getu (srcsys, template[i]; kws... )
852+ end
853+ end
854+ return CopyParamsByTemplate (true , template, size (syms))
855+ end
856+
857+ function CopyParamsByTemplate (srcsys:: AbstractSystem , syms:: AbstractArray ; kws... )
858+ return CopyParamsByTemplate (false , [CopyParamsByTemplate (srcsys, sym; kws... ) for sym in syms], size (syms))
859+ end
860+
861+ struct MTKParametersReconstructor{T, I, D, C, N}
862+ tunables_fn:: T
863+ initials_fn:: I
864+ discretes_fn:: D
865+ consts_fn:: C
866+ nonnumerics_fn:: N
867+ diffcache_buffer_idx:: Int
868+ end
869+
870+ function (recon:: MTKParametersReconstructor )(src, dst)
871+ src_ps = parameter_values (src)
872+ dst_ps = parameter_values (dst)
873+ oldcache = dst_ps. caches
874+ # These type-assertions allow `remake` to infer even though `CopyParamsByTemplate` doesn't
875+ # TODO : Find a solution to type-assert constants/discretes/nonnumerics
876+ # TODO : Maybe `CopyParamsByTemplate` could use a specializing tuple-form when `FullSpecialize`?
877+ if dst_ps. tunable isa SVector{0 }
878+ tunablevals = recon. tunables_fn (src)
879+ else
880+ tunable_elT = promote_type (eltype (dst_ps. tunable), eltype (src_ps. tunable))
881+ if ArrayInterface. ismutable (dst_ps. tunable)
882+ tunable_T = Base. promote_op (similar, typeof (dst_ps. tunable), Type{tunable_elT})
883+ tunablevals = recon. tunables_fn (src):: tunable_T
884+ else
885+ tunable_T = StaticArraysCore. similar_type (typeof (dst_ps. tunable), tunable_elT)
886+ tunablevals = recon. tunables_fn (src):: tunable_T
887+ end
888+ end
889+ if dst_ps. initials isa SVector{0 }
890+ initialvals = recon. initials_fn (src)
891+ else
892+ initial_elT = promote_type (eltype (dst_ps. initials), eltype (src_ps. initials))
893+ if ArrayInterface. ismutable (dst_ps. initials)
894+ initial_T = Base. promote_op (similar, typeof (dst_ps. initials), Type{initial_elT})
895+ initialvals = recon. initials_fn (src):: initial_T
896+ else
897+ initial_T = StaticArraysCore. similar_type (typeof (dst_ps. initials), initial_elT)
898+ initialvals = recon. initials_fn (src):: initial_T
899+ end
900+ end
901+
902+ nonnumerics = recon. nonnumerics_fn (src)
903+ (; diffcache_buffer_idx) = recon
904+ if ! iszero (diffcache_buffer_idx)
905+ @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
906+ end
907+ return promote_with_nothing (
908+ promote_type_with_nothing (eltype (tunablevals), initialvals),
909+ MTKParameters (
910+ tunablevals, initialvals, recon. discretes_fn (src),
911+ recon. consts_fn (src), nonnumerics, oldcache isa Tuple{} ? () : copy .(oldcache)
912+ )
913+ )
914+ end
915+
721916"""
722917 $(TYPEDSIGNATURES)
723918
@@ -732,10 +927,10 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
732927 unwrapped.
733928- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
734929"""
735- function get_mtkparameters_reconstructor (
930+ function MTKParametersReconstructor (
736931 srcsys:: AbstractSystem , dstsys:: AbstractSystem ;
737932 initials = false , unwrap_initials = false , p_constructor = identity,
738- eval_expression = false , eval_module = @__MODULE__ , force_time_independent = false ,
933+ force_time_independent = false ,
739934 kwargs...
740935 )
741936 _p_constructor = p_constructor
@@ -754,32 +949,27 @@ function get_mtkparameters_reconstructor(
754949 tunable_getter = if isempty (tunable_syms)
755950 Returns (SVector {0, Float64} ())
756951 else
757- p_constructor ∘ concrete_getu (
758- srcsys, tunable_syms; eval_expression, eval_module,
759- force_time_independent, kwargs...
760- )
952+ p_constructor ∘ CopyParamsByTemplate (srcsys, tunable_syms; kwargs... )
761953 end
762954 initials_getter = if initials && ! isempty (syms[2 ])
763- initsyms = Vector {Any} ( syms[2 ])
764- allsyms = Set (variable_symbols (srcsys))
955+ initsyms = syms[2 ]:: Vector{SymbolicT}
956+ allsyms = Set {SymbolicT} (variable_symbols (srcsys))
765957 if unwrap_initials
766958 for i in eachindex (initsyms)
767959 sym = initsyms[i]
768- innersym = if operation (sym) === getindex
769- sym, idxs... = arguments (sym)
770- only (arguments (sym))[idxs... ]
960+ arr, isarr = split_indexed_var (sym)
961+ innersym = if isarr
962+ sidx = get_stable_index (sym)
963+ first (arguments (arr))[sidx]
771964 else
772- only (arguments (sym ))
965+ first (arguments (arr ))
773966 end
774967 if innersym in allsyms
775968 initsyms[i] = innersym
776969 end
777970 end
778971 end
779- p_constructor ∘ concrete_getu (
780- srcsys, initsyms; eval_expression, eval_module,
781- force_time_independent, kwargs...
782- )
972+ p_constructor ∘ CopyParamsByTemplate (srcsys, initsyms; kwargs... )
783973 else
784974 Returns (SVector {0, Float64} ())
785975 end
@@ -792,29 +982,25 @@ function get_mtkparameters_reconstructor(
792982 p_constructor (map (x -> x. length, bufsizes))
793983 end
794984 )
985+
795986 # discretes need to be blocked arrays
796987 # the `getu` returns a tuple of arrays corresponding to `p.discretes`
797988 # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
798989 # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
799990 # tuple of `BlockedArray`s
800991 Base. Fix2 (Broadcast. BroadcastFunction (BlockedArray), blockarrsizes) ∘
801- Base. Fix1 (broadcast, p_constructor) ∘
992+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘
802993 # This `broadcast.(collect, ...)` avoids `ReshapedArray`/`SubArray`s from
803994 # appearing in the result.
804- concrete_getu (
805- srcsys, Tuple (broadcast .(collect, syms[3 ]));
806- eval_expression, eval_module, force_time_independent, kwargs...
807- )
995+ CopyParamsByTemplate (srcsys, broadcast .(collect, syms[3 ]); kwargs... )
808996 end
809- const_getter = if syms[4 ] == ( )
997+ const_getter = if isempty ( syms[4 ])
810998 Returns (())
811999 else
812- Base. Fix1 (broadcast, p_constructor) ∘ concrete_getu (
813- srcsys, Tuple (syms[4 ]);
814- eval_expression, eval_module, force_time_independent, kwargs...
815- )
1000+ Base. Fix1 (broadcast, p_constructor) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[4 ]; kwargs... )
8161001 end
817- nonnumeric_getter = if syms[5 ] == ()
1002+ diffcache_buffer_idx = 0
1003+ nonnumeric_getter = if isempty (syms[5 ])
8181004 Returns (())
8191005 else
8201006 ic = get_index_cache (dstsys)
@@ -825,44 +1011,19 @@ function get_mtkparameters_reconstructor(
8251011 )
8261012
8271013 diffcache_params = SU. getmetadata (dstsys, DiffCacheParams, Dict {SymbolicT, Int} ()):: Dict{SymbolicT, Int}
828- diffcache_buffer_idx = 0
8291014 if ! isempty (diffcache_params)
8301015 representative = first (keys (diffcache_params))
8311016 diffcache_buffer_idx, _ = ic. nonnumeric_idx[representative]
8321017 @set! buftypes[diffcache_buffer_idx] = identity
1018+ for (i, sym) in enumerate (syms[5 ][diffcache_buffer_idx])
1019+ end
8331020 end
8341021 # nonnumerics retain the assigned buffer type without narrowing
8351022 Base. Fix1 (broadcast, _p_constructor) ∘
836- Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘
837- concrete_getu (
838- srcsys, Tuple (syms[5 ]);
839- eval_expression, eval_module, force_time_independent, kwargs...
840- )
841- end
842- getters = (
843- tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter,
844- )
845- getter = let getters = getters, diffcache_buffer_idx = diffcache_buffer_idx
846- function _getter (valp, initprob)
847- oldcache = parameter_values (initprob). caches
848- tunablevals = getters[1 ](valp)
849- initialvals = getters[2 ](valp)
850- nonnumerics = getters[5 ](valp)
851- if ! iszero (diffcache_buffer_idx)
852- @set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper {ForwardDiff.valtype(eltype(initialvals))} .(nonnumerics[diffcache_buffer_idx])
853- end
854- return promote_with_nothing (
855- promote_type_with_nothing (eltype (tunablevals), initialvals),
856- MTKParameters (
857- tunablevals, initialvals, getters[3 ](valp),
858- getters[4 ](valp), nonnumerics, oldcache isa Tuple{} ? () :
859- copy .(oldcache)
860- )
861- )
862- end
1023+ Base. Fix1 (Broadcast. BroadcastFunction (call), buftypes) ∘ Tuple ∘ CopyParamsByTemplate (srcsys, syms[5 ]; kwargs... )
8631024 end
8641025
865- return getter
1026+ return MTKParametersReconstructor (tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter, diffcache_buffer_idx)
8661027end
8671028
8681029function call (f, args... )
@@ -890,7 +1051,7 @@ function ReconstructInitializeprob(
8901051 kwargs...
8911052 )
8921053 if is_split (dstsys)
893- pgetter = get_mtkparameters_reconstructor (
1054+ pgetter = MTKParametersReconstructor (
8941055 srcsys, dstsys; p_constructor, eval_expression, eval_module,
8951056 force_time_independent = is_steadystateprob, kwargs...
8961057 )
@@ -970,7 +1131,7 @@ function construct_initializeprobpmap(
9701131 )
9711132 @assert is_initializesystem (initsys)
9721133 if is_split (sys)
973- return let getter = get_mtkparameters_reconstructor (
1134+ return let getter = MTKParametersReconstructor (
9741135 initsys, sys; initials = true , unwrap_initials = true , p_constructor,
9751136 eval_expression, eval_module, kwargs...
9761137 )
0 commit comments