@@ -105,6 +105,46 @@ def __init__(
105105 if not self ._is_fixed_param :
106106 self ._check_sampler_diagnostics ()
107107
108+ def create_inits (
109+ self , seed : Optional [int ] = None , chains : int = 4
110+ ) -> Union [List [Dict [str , np .ndarray ]], Dict [str , np .ndarray ]]:
111+ """
112+ Create initial values for the parameters of the model by
113+ randomly selecting draws from the MCMC samples. If the samples
114+ contain draws from multiple chains, each draw will be from
115+ a different chain, if possible. Otherwise the chain is randomly
116+ selected.
117+
118+ :param seed: Used for random selection, defaults to None
119+ :param chains: Number of initial values to return, defaults to 4
120+ :return: The initial values for the parameters of the model.
121+
122+ If ``chains`` is 1, a dictionary is returned, otherwise a list
123+ of dictionaries is returned, in the format expected for the
124+ ``inits`` argument of :meth:`CmdStanModel.sample`.
125+ """
126+ self ._assemble_draws ()
127+ rng = np .random .default_rng (seed )
128+ n_draws , n_chains = self ._draws .shape [:2 ]
129+ draw_idxs = rng .choice (n_draws , size = chains , replace = False )
130+ chain_idxs = rng .choice (
131+ n_chains , size = chains , replace = (n_chains <= chains )
132+ )
133+ if chains == 1 :
134+ draw = self ._draws [draw_idxs [0 ], chain_idxs [0 ]]
135+ return {
136+ name : var .extract_reshape (draw )
137+ for name , var in self ._metadata .stan_vars .items ()
138+ }
139+ else :
140+ return [
141+ {
142+ name : var .extract_reshape (self ._draws [d , i ])
143+ for name , var in self ._metadata .stan_vars .items ()
144+ }
145+ for d , i in zip (draw_idxs , chain_idxs )
146+ ]
147+
108148 def __repr__ (self ) -> str :
109149 repr = 'CmdStanMCMC: model={} chains={}{}' .format (
110150 self .runset .model ,
@@ -685,7 +725,7 @@ def draws_xr(
685725 )
686726 if inc_warmup and not self ._save_warmup :
687727 get_logger ().warning (
688- " Draws from warmup iterations not available,"
728+ ' Draws from warmup iterations not available,'
689729 ' must run sampler with "save_warmup=True".'
690730 )
691731 if vars is None :
0 commit comments