@@ -328,9 +328,8 @@ def __init__(
328328 self ._metric_type = None
329329 self ._metric = None
330330 self ._stepsize = None
331- self ._sample = None
332- self ._warmup = None
333- self ._drawset = None
331+ self ._draws = None
332+ self ._draws_as_df = None
334333 self ._stan_variable_dims = {}
335334 self ._validate_csv = validate_csv
336335 if validate_csv :
@@ -361,17 +360,14 @@ def chain_ids(self) -> List[int]:
361360
362361 @property
363362 def num_draws (self ) -> int :
364- """Number of post-warmup draws per chain."""
363+ """Number of draws per chain."""
365364 if not self ._validate_csv and self ._draws_sampling is None :
366- return int (math .ceil (self ._iter_sampling / self ._thin ))
367- return self ._draws_sampling
368-
369- @property
370- def num_draws_warmup (self ) -> int :
371- """Number of warmup draws per chain."""
372- if not self ._validate_csv and self ._draws_warmup is None :
373- return int (math .ceil (self ._iter_warmup / self ._thin ))
374- return self ._draws_warmup
365+ return int (
366+ math .ceil (
367+ (self ._iter_sampling + self ._iter_warmup ) / self ._thin
368+ )
369+ )
370+ return self ._draws_warmup + self ._draws_sampling
375371
376372 @property
377373 def column_names (self ) -> Tuple [str , ...]:
@@ -432,8 +428,8 @@ def metric(self) -> np.ndarray:
432428 ' in order to retrieve sample metadata.'
433429 )
434430 return None
435- if self ._sample is None :
436- self ._assemble_sample ()
431+ if self ._draws is None :
432+ self ._assemble_draws ()
437433 return self ._metric
438434
439435 @property
@@ -450,39 +446,58 @@ def stepsize(self) -> np.ndarray:
450446 ' in order to retrieve sample metadata.'
451447 )
452448 return None
453- if self ._sample is None :
454- self ._assemble_sample ()
449+ if self ._draws is None :
450+ self ._assemble_draws ()
455451 return self ._stepsize
456452
457- @property
458- def sample (self ) -> np .ndarray :
453+ def draws (self , inc_warmup : bool = False ) -> np .ndarray :
459454 """
460- A 3-D numpy ndarray which contains all draws across all chains
461- arranged as (draws, chains, columns) stored column major
462- so that the values for each parameter are stored contiguously
455+ A 3-D numpy ndarray which contains all draws, from both warmup and
456+ sampling iterations, arranged as (draws, chains, columns) and stored
457+ column major, so that the values for each parameter are contiguous
463458 in memory, likewise all draws from a chain are contiguous.
459+
460+ :param inc_warmup: When ``True`` and the warmup draws are present in
461+ the output, i.e., the sampler was run with ``save_warmup=True``,
462+ then the warmup draws are included. Default value is ``False``.
464463 """
465- if not self ._validate_csv and self ._sample is None :
464+ if not self ._validate_csv and self ._draws is None :
466465 self .validate_csv_files ()
467- if self ._sample is None :
468- self ._assemble_sample ()
469- return self ._sample
466+ if self ._draws is None :
467+ self ._assemble_draws ()
468+ if not inc_warmup :
469+ if self ._save_warmup :
470+ return self ._draws [self ._draws_warmup :, :, :]
471+ return self ._draws
472+ else :
473+ if not self ._save_warmup :
474+ self ._logger .warning (
475+ 'draws from warmup iterations not available,'
476+ ' must run sampler with "save_warmup=True".'
477+ )
478+ return self ._draws
479+
480+ @property
481+ def sample (self ) -> np .ndarray :
482+ """
483+ Deprecated - use method "draws()" instead.
484+ """
485+ self ._logger .warning (
486+ 'method "sample" will be deprecated, use method "draws" instead.'
487+ )
488+ return self .draws ()
470489
471490 @property
472491 def warmup (self ) -> np .ndarray :
473492 """
474- A 3-D numpy ndarray which contains all warmup draws across all chains
475- arranged as (draws, chains, columns) stored column major
476- so that the values for each parameter are stored contiguously
477- in memory, likewise all draws from a chain are contiguous.
493+ Deprecated - use "draws(inc_warmup=True)"
478494 """
479- if not self ._save_warmup :
480- return None
481- if not self ._validate_csv and self ._sample is None :
482- self .validate_csv_files ()
483- if self ._sample is None :
484- self ._assemble_sample ()
485- return self ._warmup
495+ self ._logger .warning (
496+ 'method "warmup" has been deprecated, instead use method'
497+ ' "draws(inc_warmup=True)", returning draws from both'
498+ ' warmup and sampling iterations.'
499+ )
500+ return self .draws (inc_warmup = True )
486501
487502 def validate_csv_files (self ) -> None :
488503 """
@@ -535,28 +550,21 @@ def validate_csv_files(self) -> None:
535550 self ._metric_type = dzero .get ('metric' )
536551 self ._stan_variable_dims = parse_var_dims (dzero ['column_names' ])
537552
538- def _assemble_sample (self ) -> None :
553+ def _assemble_draws (self ) -> None :
539554 """
540555 Allocates and populates the stepsize, metric, and sample arrays
541556 by parsing the validated stan_csv files.
542557 """
543- if self ._sample is not None :
558+ if self ._draws is not None :
544559 return
545- self ._sample = np .empty (
546- (self ._draws_sampling , self .runset .chains , len (self ._column_names )),
560+ num_draws = self ._draws_sampling
561+ if self ._save_warmup :
562+ num_draws += self ._draws_warmup
563+ self ._draws = np .empty (
564+ (num_draws , self .runset .chains , len (self ._column_names )),
547565 dtype = float ,
548566 order = 'F' ,
549567 )
550- if self ._save_warmup :
551- self ._warmup = np .empty (
552- (
553- self ._draws_warmup ,
554- self .runset .chains ,
555- len (self ._column_names ),
556- ),
557- dtype = float ,
558- order = 'F' ,
559- )
560568 if not self ._is_fixed_param :
561569 self ._stepsize = np .empty (self .runset .chains , dtype = float )
562570 if self ._metric_type == 'diag_e' :
@@ -580,7 +588,7 @@ def _assemble_sample(self) -> None:
580588 for i in range (self ._draws_warmup ):
581589 line = fd .readline ().strip ()
582590 xs = line .split (',' )
583- self ._warmup [i , chain , :] = [float (x ) for x in xs ]
591+ self ._draws [i , chain , :] = [float (x ) for x in xs ]
584592 # read to adaptation msg
585593 if line != '# Adaptation terminated' :
586594 while line != '# Adaptation terminated' :
@@ -600,10 +608,10 @@ def _assemble_sample(self) -> None:
600608 xs = line .split (',' )
601609 self ._metric [chain , i , :] = [float (x ) for x in xs ]
602610 # process draws
603- for i in range (self ._draws_sampling ):
611+ for i in range (self ._draws_warmup , num_draws ):
604612 line = fd .readline ().strip ()
605613 xs = line .split (',' )
606- self ._sample [i , chain , :] = [float (x ) for x in xs ]
614+ self ._draws [i , chain , :] = [float (x ) for x in xs ]
607615
608616 def summary (self , percentiles : List [int ] = None ) -> pd .DataFrame :
609617 """
@@ -676,39 +684,47 @@ def diagnose(self) -> str:
676684 self .runset ._logger .info (result )
677685 return result
678686
679- def get_drawset (self , params : List [str ] = None ) -> pd .DataFrame :
687+ def draws_as_dataframe (
688+ self , params : List [str ] = None , inc_warmup : bool = False
689+ ) -> pd .DataFrame :
680690 """
681- Returns the assembled sample as a pandas DataFrame consisting of
691+ Returns the assembled draws as a pandas DataFrame consisting of
682692 one column per parameter and one row per draw.
683693
684694 :param params: list of model parameter names.
695+
696+ :param inc_warmup: When ``True`` and the warmup draws are present in
697+ the output, i.e., the sampler was run with ``save_warmup=True``,
698+ then the warmup draws are included. Default value is ``False``.
685699 """
686700 pnames_base = [name .split ('.' )[0 ] for name in self .column_names ]
687701 if params is not None :
688702 for param in params :
689703 if not (param in self ._column_names or param in pnames_base ):
690704 raise ValueError ('unknown parameter: {}' .format (param ))
691- self ._assemble_sample ()
692- if self ._drawset is None :
705+ self ._assemble_draws ()
706+ if self ._draws_as_df is None :
693707 # pylint: disable=redundant-keyword-arg
694- data = self .sample .reshape (
708+ data = self .draws ( inc_warmup = inc_warmup ) .reshape (
695709 (self .num_draws * self .runset .chains ),
696710 len (self .column_names ),
697711 order = 'A' ,
698712 )
699- self ._drawset = pd .DataFrame (data = data , columns = self .column_names )
713+ self ._draws_as_df = pd .DataFrame (
714+ data = data , columns = self .column_names
715+ )
700716 if params is None :
701- return self ._drawset
717+ return self ._draws_as_df
702718 mask = []
703719 params = set (params )
704720 for name in self .column_names :
705721 if any (item in params for item in (name , name .split ('.' )[0 ])):
706722 mask .append (name )
707- return self ._drawset [mask ]
723+ return self ._draws_as_df [mask ]
708724
709725 def stan_variable (self , name : str ) -> np .ndarray :
710726 """
711- Return a new ndarray which contains the set of draws
727+ Return a new ndarray which contains the set of post-warmup draws
712728 for the named Stan program variable. Flattens the chains.
713729 Underlyingly draws are in chain order, i.e., for a sample
714730 consisting of N chains of M draws each, the first M array
@@ -728,12 +744,14 @@ def stan_variable(self, name: str) -> np.ndarray:
728744 """
729745 if name not in self ._stan_variable_dims :
730746 raise ValueError ('unknown name: {}' .format (name ))
731- self ._assemble_sample ()
747+ self ._assemble_draws ()
732748 dim0 = self .num_draws * self .runset .chains
733749 dims = self ._stan_variable_dims [name ]
734750 if dims == 1 :
735751 idx = self .column_names .index (name )
736- return self .sample [:, :, idx ].reshape ((dim0 ,), order = 'A' )
752+ return self ._draws [self ._draws_warmup :, :, idx ].reshape (
753+ (dim0 ,), order = 'A'
754+ )
737755 else :
738756 idxs = [
739757 x [0 ]
@@ -742,9 +760,9 @@ def stan_variable(self, name: str) -> np.ndarray:
742760 ]
743761 var_dims = [dim0 ]
744762 var_dims .extend (dims )
745- return self .sample [:, :, idxs [ 0 ] : idxs [ - 1 ] + 1 ]. reshape (
746- tuple ( var_dims ), order = 'A'
747- )
763+ return self ._draws [
764+ self . _draws_warmup :, :, idxs [ 0 ] : idxs [ - 1 ] + 1
765+ ]. reshape ( tuple ( var_dims ), order = 'A' )
748766
749767 def stan_variables (self ) -> Dict :
750768 """
@@ -762,10 +780,10 @@ def sampler_diagnostics(self) -> Dict:
762780 column name to draws X chains X 1 ndarray.
763781 """
764782 result = {}
765- self ._assemble_sample ()
783+ self ._assemble_draws ()
766784 diag_names = [x for x in self .column_names if x .endswith ('__' )]
767785 for idx , value in enumerate (diag_names ):
768- result [value ] = self .sample [:, :, idx ]
786+ result [value ] = self ._draws [:, :, idx ]
769787 return result
770788
771789 def save_csvfiles (self , dir : str = None ) -> None :
0 commit comments