66import jax
77import pandas as pd
88from dags import concatenate_functions , with_signature
9- from dags .tree import qname_from_tree_path , tree_path_from_qname
9+ from dags .tree import qname_from_tree_path
1010from jax import Array
1111
1212from lcm .grids import Grid
@@ -69,6 +69,17 @@ def get_next_state_function_for_simulation(
6969) -> NextStateSimulationFunction :
7070 """Get function that computes the next states during the simulation.
7171
72+ Builds one DAG per target regime using unqualified `next_<state>` keys, mirroring
73+ the per-target structure of {func}`get_next_state_function_for_solution`. This
74+ lets a transition function or auxiliary regime function consume another
75+ transition's `next_<state>` output via plain name resolution within the same
76+ target's DAG. The combined function returns a nested mapping keyed by target
77+ regime name, with each inner dict using unqualified `next_<state>` keys.
78+
79+ Stochastic-transition wrappers expose `key_<target>__next_<state>` and
80+ `weight_<target>__next_<state>` as external arguments so callers can pass a
81+ distinct random key and pre-computed weight per target.
82+
7283 Args:
7384 transitions: Nested mapping of target regime names to transition functions.
7485 functions: Immutable mapping of auxiliary functions of a regime.
@@ -78,26 +89,31 @@ def get_next_state_function_for_simulation(
7889
7990 Returns:
8091 Function that computes the next states. Depends on states and actions of the
81- current period, and the regime parameters ("params"). If target is "simulate",
82- the function also depends on the dictionary of random keys ("keys"), which
83- corresponds to the names of stochastic next functions .
92+ current period, and the regime parameters ("params"). The function also
93+ depends on the dictionary of random keys ("keys") for stochastic transitions.
94+ Returns `{target_regime_name: {next_<state>: array}}` .
8495
8596 """
86- flat_transitions = flatten_regime_namespace (transitions )
87-
88- # For the simulation target, we need to extend the functions dictionary with
89- # stochastic next states functions and their weights.
90- extended_transitions = _extend_transitions_for_simulation (
91- all_grids = all_grids ,
92- flat_transitions = flat_transitions ,
93- variable_info = variable_info ,
94- stochastic_transition_names = stochastic_transition_names ,
95- )
96- functions_to_concatenate = extended_transitions | dict (functions )
97+ per_target_funcs : dict [RegimeName , Callable [..., dict [str , Array ]]] = {}
98+ for target , target_transitions in transitions .items ():
99+ extended = _extend_target_transitions_for_simulation (
100+ target = target ,
101+ target_transitions = target_transitions ,
102+ all_grids = all_grids ,
103+ variable_info = variable_info ,
104+ stochastic_transition_names = stochastic_transition_names ,
105+ )
106+ per_target_funcs [target ] = concatenate_functions (
107+ functions = dict (extended ) | dict (functions ),
108+ targets = list (extended .keys ()),
109+ return_type = "dict" ,
110+ enforce_signature = False ,
111+ set_annotations = True ,
112+ )
97113
98114 return concatenate_functions (
99- functions = functions_to_concatenate ,
100- targets = list (flat_transitions .keys ()),
115+ functions = per_target_funcs ,
116+ targets = list (per_target_funcs .keys ()),
101117 return_type = "dict" ,
102118 enforce_signature = False ,
103119 set_annotations = True ,
@@ -137,64 +153,59 @@ def get_next_stochastic_weights_function(
137153 )
138154
139155
140- def _extend_transitions_for_simulation (
156+ def _extend_target_transitions_for_simulation (
141157 * ,
158+ target : RegimeName ,
159+ target_transitions : MappingProxyType [TransitionFunctionName , Callable [..., Array ]],
142160 all_grids : MappingProxyType [RegimeName , MappingProxyType [StateOrActionName , Grid ]],
143- flat_transitions : FunctionsMapping ,
144161 variable_info : pd .DataFrame ,
145162 stochastic_transition_names : frozenset [TransitionFunctionName ],
146163) -> dict [TransitionFunctionName , Callable [..., Array ]]:
147- """Extend the functions dictionary for the simulation target.
164+ """Replace stochastic transitions for one target with realisation wrappers.
165+
166+ Deterministic transitions are passed through unchanged. Stochastic transitions
167+ are replaced by wrappers that draw a realisation from a precomputed weight
168+ vector and a random key. The wrapper's external argument names use
169+ target-qualified form (`key_<target>__<next_state>`,
170+ `weight_<target>__<next_state>`) so multi-target callers can supply distinct
171+ random keys per target. The dict key keeps the unqualified `next_<state>` so
172+ other transitions or regime functions in the same target's DAG can resolve
173+ it by name.
148174
149175 Args:
176+ target: Target regime name.
177+ target_transitions: Mapping of unqualified `next_<state>` transition names
178+ to functions, restricted to one target regime.
150179 all_grids: Immutable mapping of regime names to Grid spec objects.
151- flat_transitions: Flattened mapping of transition names to functions.
152180 variable_info: Variable info of the current regime.
153181 stochastic_transition_names: Frozenset of stochastic transition function names.
154182
155183 Returns:
156- Extended functions dictionary.
184+ Extended transitions dictionary keyed by unqualified `next_<state>` names .
157185
158186 """
159187 shock_names : set [ShockName ] = set (variable_info .query ("is_shock" ).index .to_list ())
160188 flat_grids = flatten_regime_namespace (all_grids )
161- discrete_stochastic_targets = [
162- func_name
163- for func_name in flat_transitions
164- if tree_path_from_qname (func_name )[- 1 ] in stochastic_transition_names
165- and tree_path_from_qname (func_name )[- 1 ].removeprefix ("next_" ) not in shock_names
166- ]
167- continuous_stochastic_targets = [
168- func_name
169- for func_name in flat_transitions
170- if tree_path_from_qname (func_name )[- 1 ] in stochastic_transition_names
171- and tree_path_from_qname (func_name )[- 1 ].removeprefix ("next_" ) in shock_names
172- ]
173- # Handle stochastic next states functions
174- # ----------------------------------------------------------------------------------
175- # We generate stochastic next states functions that simulate the next state given
176- # a random key (think of a seed) and the weights corresponding to the labels of the
177- # stochastic variable. The weights are computed using the stochastic weight
178- # functions, which we add the to functions dict. `dags.concatenate_functions` then
179- # generates a function that computes the weights and simulates the next state in
180- # one go.
181- # ----------------------------------------------------------------------------------
182- discrete_stochastic_next = {
183- name : _create_discrete_stochastic_next_func (
184- name = name , labels = flat_grids [name .replace ("next_" , "" )].to_jax ()
185- )
186- for name in discrete_stochastic_targets
187- }
188- continuous_stochastic_next = {
189- name : _create_continuous_stochastic_next_func (name = name , flat_grids = flat_grids )
190- for name in continuous_stochastic_targets
191- }
192-
193- # Overwrite regime transitions with generated stochastic next states functions
194- # ----------------------------------------------------------------------------------
195- return (
196- dict (flat_transitions ) | discrete_stochastic_next | continuous_stochastic_next
189+ extended : dict [TransitionFunctionName , Callable [..., Array ]] = dict (
190+ target_transitions
197191 )
192+ for next_state_name in target_transitions :
193+ if next_state_name not in stochastic_transition_names :
194+ continue
195+ qname = qname_from_tree_path ((target , next_state_name ))
196+ raw_state_name = next_state_name .removeprefix ("next_" )
197+ if raw_state_name in shock_names :
198+ extended [next_state_name ] = _create_continuous_stochastic_next_func (
199+ name = qname , flat_grids = flat_grids
200+ )
201+ else :
202+ extended [next_state_name ] = _create_discrete_stochastic_next_func (
203+ name = qname ,
204+ labels = flat_grids [
205+ qname_from_tree_path ((target , raw_state_name ))
206+ ].to_jax (),
207+ )
208+ return extended
198209
199210
200211def _create_discrete_stochastic_next_func (
0 commit comments