@@ -702,7 +702,7 @@ def draws_as_dataframe(
702702 the output, i.e., the sampler was run with ``save_warmup=True``,
703703 then the warmup draws are included. Default value is ``False``.
704704 """
705- pnames_base = [name .split ('. ' )[0 ] for name in self .column_names ]
705+ pnames_base = [name .split ('[ ' )[0 ] for name in self .column_names ]
706706 if params is not None :
707707 for param in params :
708708 if not (param in self ._column_names or param in pnames_base ):
@@ -723,51 +723,47 @@ def draws_as_dataframe(
723723 mask = []
724724 params = set (params )
725725 for name in self .column_names :
726- if any (item in params for item in (name , name .split ('. ' )[0 ])):
726+ if any (item in params for item in (name , name .split ('[ ' )[0 ])):
727727 mask .append (name )
728728 return self ._draws_as_df [mask ]
729729
730- def stan_variable (self , name : str ) -> np . ndarray :
730+ def stan_variable (self , name : str ) -> pd . DataFrame :
731731 """
732- Return a new ndarray which contains the set of post-warmup draws
732+ Return a new DataFrame which contains the set of post-warmup draws
733733 for the named Stan program variable. Flattens the chains.
734734 Underlyingly draws are in chain order, i.e., for a sample
735735 consisting of N chains of M draws each, the first M array
736736 elements are from chain 1, the next M are from chain 2,
737737 and the last M elements are from chain N.
738738
739- * If the variable is a scalar variable, this returns a 1-d array,
740- length( draws X chains).
741- * If the variable is a vector, this is a 2-d array,
742- shape ( draws X chains, len(vector))
743- * If the variable is a matrix, this is a 3-d array,
744- shape ( draws X chains, matrix nrows, matrix ncols ).
745- * If the variable is an array with N dimensions, this is an N+1-d array,
746- shape ( draws X chains, size(dim 1), ... size(dim N)).
739+ * If the variable is a scalar variable, the shape of the DataFrame is
740+ ( draws X chains, 1 ).
741+ * If the variable is a vector, the shape of the DataFrame is
742+ ( draws X chains, len(vector))
743+ * If the variable is a matrix, the shape of the DataFrame is
744+ ( draws X chains, size(dim 1) X size(dim 2) )
745+ * If the variable is an array with N dimensions, the shape of the
746+ DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N))
747747
748748 :param name: variable name
749749 """
750750 if name not in self ._stan_variable_dims :
751751 raise ValueError ('unknown name: {}' .format (name ))
752752 self ._assemble_draws ()
753753 dim0 = self .num_draws * self .runset .chains
754- dims = self ._stan_variable_dims [name ]
755- if dims == 1 :
756- idx = self .column_names .index (name )
757- return self ._draws [self ._draws_warmup :, :, idx ].reshape (
758- (dim0 ,), order = 'A'
759- )
760- else :
761- idxs = [
762- x [0 ]
763- for x in enumerate (self .column_names )
764- if x [1 ].startswith (name + '.' )
765- ]
766- var_dims = [dim0 ]
767- var_dims .extend (dims )
768- return self ._draws [
769- self ._draws_warmup :, :, idxs [0 ] : idxs [- 1 ] + 1
770- ].reshape (tuple (var_dims ), order = 'A' )
754+ dims = np .prod (self ._stan_variable_dims [name ])
755+ pattern = r'^{}(\[[\d,]+\])?$' .format (name )
756+ names , idxs = [], []
757+ for i , column_name in enumerate (self .column_names ):
758+ if re .search (pattern , column_name ):
759+ names .append (column_name )
760+ idxs .append (i )
761+ return pd .DataFrame (
762+ self ._draws [
763+ self ._draws_warmup :, :, idxs
764+ ].reshape ((dim0 , dims ), order = 'A' ),
765+ columns = names
766+ )
771767
772768 def stan_variables (self ) -> Dict :
773769 """
0 commit comments