Skip to content

Commit dba7b51

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents d8cf8b8 + 798e178 commit dba7b51

File tree

7 files changed

+185
-177
lines changed

7 files changed

+185
-177
lines changed

cmdstanpy/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __repr__(self) -> str:
161161
repr = 'CmdStanModel: name={}'.format(self._name)
162162
repr = '{}\n\t stan_file={}'.format(repr, self._stan_file)
163163
repr = '{}\n\t exe_file={}'.format(repr, self._exe_file)
164-
repr = '{}\n\t compiler_optons={}'.format(repr, self._compiler_options)
164+
repr = '{}\n\t compiler_options={}'.format(repr, self._compiler_options)
165165
return repr
166166

167167
@property

cmdstanpy/stanfit.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ def draws_as_dataframe(
702702
the output, i.e., the sampler was run with ``save_warmup=True``,
703703
then the warmup draws are included. Default value is ``False``.
704704
"""
705-
pnames_base = [name.split('.')[0] for name in self.column_names]
705+
pnames_base = [name.split('[')[0] for name in self.column_names]
706706
if params is not None:
707707
for param in params:
708708
if not (param in self._column_names or param in pnames_base):
@@ -723,51 +723,47 @@ def draws_as_dataframe(
723723
mask = []
724724
params = set(params)
725725
for name in self.column_names:
726-
if any(item in params for item in (name, name.split('.')[0])):
726+
if any(item in params for item in (name, name.split('[')[0])):
727727
mask.append(name)
728728
return self._draws_as_df[mask]
729729

730-
def stan_variable(self, name: str) -> np.ndarray:
730+
def stan_variable(self, name: str) -> pd.DataFrame:
731731
"""
732-
Return a new ndarray which contains the set of post-warmup draws
732+
Return a new DataFrame which contains the set of post-warmup draws
733733
for the named Stan program variable. Flattens the chains.
734734
Underlyingly draws are in chain order, i.e., for a sample
735735
consisting of N chains of M draws each, the first M array
736736
elements are from chain 1, the next M are from chain 2,
737737
and the last M elements are from chain N.
738738
739-
* If the variable is a scalar variable, this returns a 1-d array,
740-
length(draws X chains).
741-
* If the variable is a vector, this is a 2-d array,
742-
shape ( draws X chains, len(vector))
743-
* If the variable is a matrix, this is a 3-d array,
744-
shape ( draws X chains, matrix nrows, matrix ncols ).
745-
* If the variable is an array with N dimensions, this is an N+1-d array,
746-
shape ( draws X chains, size(dim 1), ... size(dim N)).
739+
* If the variable is a scalar variable, the shape of the DataFrame is
740+
( draws X chains, 1).
741+
* If the variable is a vector, the shape of the DataFrame is
742+
( draws X chains, len(vector))
743+
* If the variable is a matrix, the shape of the DataFrame is
744+
( draws X chains, size(dim 1) X size(dim 2) )
745+
* If the variable is an array with N dimensions, the shape of the
746+
DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N))
747747
748748
:param name: variable name
749749
"""
750750
if name not in self._stan_variable_dims:
751751
raise ValueError('unknown name: {}'.format(name))
752752
self._assemble_draws()
753753
dim0 = self.num_draws * self.runset.chains
754-
dims = self._stan_variable_dims[name]
755-
if dims == 1:
756-
idx = self.column_names.index(name)
757-
return self._draws[self._draws_warmup :, :, idx].reshape(
758-
(dim0,), order='A'
759-
)
760-
else:
761-
idxs = [
762-
x[0]
763-
for x in enumerate(self.column_names)
764-
if x[1].startswith(name + '.')
765-
]
766-
var_dims = [dim0]
767-
var_dims.extend(dims)
768-
return self._draws[
769-
self._draws_warmup :, :, idxs[0] : idxs[-1] + 1
770-
].reshape(tuple(var_dims), order='A')
754+
dims = np.prod(self._stan_variable_dims[name])
755+
pattern = r'^{}(\[[\d,]+\])?$'.format(name)
756+
names, idxs = [], []
757+
for i, column_name in enumerate(self.column_names):
758+
if re.search(pattern, column_name):
759+
names.append(column_name)
760+
idxs.append(i)
761+
return pd.DataFrame(
762+
self._draws[
763+
self._draws_warmup:, :, idxs
764+
].reshape((dim0, dims), order='A'),
765+
columns=names
766+
)
771767

772768
def stan_variables(self) -> Dict:
773769
"""

cmdstanpy/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -636,11 +636,19 @@ def scan_column_names(fd: TextIO, config_dict: Dict, lineno: int) -> int:
636636
line = fd.readline().strip()
637637
lineno += 1
638638
names = line.split(',')
639-
config_dict['column_names'] = tuple(names)
639+
config_dict['column_names'] = tuple(_rename_columns(names))
640640
config_dict['num_params'] = len(names) - 1
641641
return lineno
642642

643643

644+
def _rename_columns(names: List) -> List:
645+
names = [
646+
re.sub(r',([\d,]+)$', r'[\1]', column.replace('.', ','))
647+
for column in names
648+
]
649+
return names
650+
651+
644652
def parse_var_dims(names: Tuple[str, ...]) -> Dict:
645653
"""
646654
Use Stan CSV file column names to get variable names, dimensions.
@@ -653,14 +661,14 @@ def parse_var_dims(names: Tuple[str, ...]) -> Dict:
653661
while idx < len(names):
654662
if names[idx].endswith('__'):
655663
pass
656-
elif '.' not in names[idx]:
664+
elif '[' not in names[idx]:
657665
vars_dict[names[idx]] = 1
658666
else:
659-
vs = names[idx].split('.')
660-
if idx < len(names) - 1 and names[idx + 1].split('.')[0] == vs[0]:
667+
vs = names[idx].split('[')
668+
if idx < len(names) - 1 and names[idx + 1].split('[')[0] == vs[0]:
661669
idx += 1
662670
continue
663-
dims = [int(vs[x]) for x in range(1, len(vs))]
671+
dims = [int(x) for x in vs[1][:-1].split(',')]
664672
vars_dict[vs[0]] = tuple(dims)
665673
idx += 1
666674
return vars_dict

test/test_generate_quantities.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ def test_gen_quantities_csv_files(self):
3737
csv_file = bern_gqs.runset.csv_files[i]
3838
self.assertTrue(os.path.exists(csv_file))
3939
column_names = [
40-
'y_rep.1',
41-
'y_rep.2',
42-
'y_rep.3',
43-
'y_rep.4',
44-
'y_rep.5',
45-
'y_rep.6',
46-
'y_rep.7',
47-
'y_rep.8',
48-
'y_rep.9',
49-
'y_rep.10',
40+
'y_rep[1]',
41+
'y_rep[2]',
42+
'y_rep[3]',
43+
'y_rep[4]',
44+
'y_rep[5]',
45+
'y_rep[6]',
46+
'y_rep[7]',
47+
'y_rep[8]',
48+
'y_rep[9]',
49+
'y_rep[10]',
5050
]
5151
self.assertEqual(bern_gqs.column_names, tuple(column_names))
5252
self.assertEqual(
@@ -104,16 +104,16 @@ def test_gen_quanties_mcmc_sample(self):
104104
csv_file = bern_gqs.runset.csv_files[i]
105105
self.assertTrue(os.path.exists(csv_file))
106106
column_names = [
107-
'y_rep.1',
108-
'y_rep.2',
109-
'y_rep.3',
110-
'y_rep.4',
111-
'y_rep.5',
112-
'y_rep.6',
113-
'y_rep.7',
114-
'y_rep.8',
115-
'y_rep.9',
116-
'y_rep.10',
107+
'y_rep[1]',
108+
'y_rep[2]',
109+
'y_rep[3]',
110+
'y_rep[4]',
111+
'y_rep[5]',
112+
'y_rep[6]',
113+
'y_rep[7]',
114+
'y_rep[8]',
115+
'y_rep[9]',
116+
'y_rep[10]',
117117
]
118118
self.assertEqual(bern_gqs.column_names, tuple(column_names))
119119
self.assertEqual(

0 commit comments

Comments
 (0)