Skip to content

Commit 102ca64

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents 1d6f6c6 + ccbc5ab commit 102ca64

8 files changed

Lines changed: 164 additions & 106 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ def validate(self, chains: int) -> None:
132132
)
133133
if self.step_size is not None:
134134
if isinstance(self.step_size, Real):
135-
if self.step_size < 0:
135+
if self.step_size <= 0:
136136
raise ValueError(
137137
'step_size must be > 0, found {}'.format(self.step_size)
138138
)
@@ -336,7 +336,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
336336
'init_alpha must not be set when algorithm is Newton'
337337
)
338338
if isinstance(self.init_alpha, Real):
339-
if self.init_alpha < 0:
339+
if self.init_alpha <= 0:
340340
raise ValueError('init_alpha must be greater than 0')
341341
else:
342342
raise ValueError('init_alpha must be type of float')
@@ -350,8 +350,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
350350

351351
# pylint: disable=unused-argument
352352
def compose(self, idx: int, cmd: List) -> str:
353-
"""compose command string for CmdStan for non-default arg values.
354-
"""
353+
"""compose command string for CmdStan for non-default arg values."""
355354
cmd.append('method=optimize')
356355
if self.algorithm:
357356
cmd.append('algorithm={}'.format(self.algorithm.lower()))
@@ -403,6 +402,7 @@ def __init__(
403402
elbo_samples: int = None,
404403
eta: Real = None,
405404
adapt_iter: int = None,
405+
adapt_engaged: bool = True,
406406
tol_rel_obj: Real = None,
407407
eval_elbo: int = None,
408408
output_samples: int = None,
@@ -413,6 +413,7 @@ def __init__(
413413
self.elbo_samples = elbo_samples
414414
self.eta = eta
415415
self.adapt_iter = adapt_iter
416+
self.adapt_engaged = adapt_engaged
416417
self.tol_rel_obj = tol_rel_obj
417418
self.eval_elbo = eval_elbo
418419
self.output_samples = output_samples
@@ -453,19 +454,19 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
453454
' found {}'.format(self.elbo_samples)
454455
)
455456
if self.eta is not None:
456-
if self.eta < 1 or not isinstance(self.eta, (Integral, Real)):
457+
if self.eta < 0 or not isinstance(self.eta, (Integral, Real)):
457458
raise ValueError(
458459
'eta must be a non-negative number,'
459460
' found {}'.format(self.eta)
460461
)
461462
if self.adapt_iter is not None:
462-
if self.adapt_iter < 1 or not isinstance(self.eta, Integral):
463+
if self.adapt_iter < 1 or not isinstance(self.adapt_iter, Integral):
463464
raise ValueError(
464465
'adapt_iter must be a positive integer,'
465466
' found {}'.format(self.adapt_iter)
466467
)
467468
if self.tol_rel_obj is not None:
468-
if self.tol_rel_obj < 1 or not isinstance(
469+
if self.tol_rel_obj <= 0 or not isinstance(
469470
self.tol_rel_obj, (Integral, Real)
470471
):
471472
raise ValueError(
@@ -503,9 +504,13 @@ def compose(self, idx: int, cmd: List) -> str:
503504
cmd.append('elbo_samples={}'.format(self.elbo_samples))
504505
if self.eta is not None:
505506
cmd.append('eta={}'.format(self.eta))
506-
if self.adapt_iter is not None:
507-
cmd.append('adapt')
508-
cmd.append('iter={}'.format(self.adapt_iter))
507+
cmd.append('adapt')
508+
if self.adapt_engaged:
509+
cmd.append('engaged=1')
510+
if self.adapt_iter is not None:
511+
cmd.append('iter={}'.format(self.adapt_iter))
512+
else:
513+
cmd.append('engaged=0')
509514
if self.tol_rel_obj is not None:
510515
cmd.append('tol_rel_obj={}'.format(self.tol_rel_obj))
511516
if self.eval_elbo is not None:
@@ -591,12 +596,12 @@ def validate(self) -> None:
591596
self._logger.info(
592597
'created output directory: %s', self.output_dir
593598
)
594-
except (RuntimeError, PermissionError):
599+
except (RuntimeError, PermissionError) as exc:
595600
raise ValueError(
596601
'invalid path for output files, no such dir: {}'.format(
597602
self.output_dir
598603
)
599-
)
604+
) from exc
600605
if not os.path.isdir(self.output_dir):
601606
raise ValueError(
602607
'specified output_dir not a directory: {}'.format(
@@ -608,11 +613,11 @@ def validate(self) -> None:
608613
with open(testpath, 'w+'):
609614
pass
610615
os.remove(testpath) # cleanup
611-
except Exception:
616+
except Exception as exc:
612617
raise ValueError(
613618
'invalid path for output files,'
614619
' cannot write to dir: {}'.format(self.output_dir)
615-
)
620+
) from exc
616621

617622
if self.seed is None:
618623
rng = RandomState()

cmdstanpy/model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,13 +850,13 @@ def generate_quantities(
850850
runset._csv_files = sample_csv_files
851851
sample_fit = CmdStanMCMC(runset)
852852
sample_drawset = sample_fit.draws_as_dataframe()
853-
except ValueError as e:
853+
except ValueError as exc:
854854
raise ValueError(
855855
'Invalid mcmc_sample, error:\n\t{}\n\t'
856856
' while processing files\n\t{}'.format(
857-
repr(e), '\n\t'.join(sample_csv_files)
857+
repr(exc), '\n\t'.join(sample_csv_files)
858858
)
859-
)
859+
) from exc
860860

861861
generate_quantities_args = GenerateQuantitiesArgs(
862862
csv_files=sample_csv_files
@@ -900,10 +900,12 @@ def variational(
900900
grad_samples: int = None,
901901
elbo_samples: int = None,
902902
eta: Real = None,
903+
adapt_engaged: bool = True,
903904
adapt_iter: int = None,
904905
tol_rel_obj: Real = None,
905906
eval_elbo: int = None,
906907
output_samples: int = None,
908+
require_converged: bool = True,
907909
) -> CmdStanVB:
908910
"""
909911
Run CmdStan's variational inference algorithm to approximate
@@ -961,6 +963,8 @@ def variational(
961963
962964
:param eta: Stepsize scaling parameter.
963965
966+
:param adapt_engaged: Whether eta adaptation is engaged.
967+
964968
:param adapt_iter: Number of iterations for eta adaptation.
965969
966970
:param tol_rel_obj: Relative tolerance parameter for convergence.
@@ -970,6 +974,9 @@ def variational(
970974
:param output_samples: Number of approximate posterior output draws
971975
to save.
972976
977+
:param require_converged: Whether or not to raise an error if stan
978+
reports that "The algorithm may not have converged".
979+
973980
:return: CmdStanVB object
974981
"""
975982
variational_args = VariationalArgs(
@@ -978,6 +985,7 @@ def variational(
978985
grad_samples=grad_samples,
979986
elbo_samples=elbo_samples,
980987
eta=eta,
988+
adapt_engaged=adapt_engaged,
981989
adapt_iter=adapt_iter,
982990
tol_rel_obj=tol_rel_obj,
983991
eval_elbo=eval_elbo,
@@ -1010,7 +1018,7 @@ def variational(
10101018
errors = re.findall(pat, contents)
10111019
if len(errors) > 0:
10121020
valid = False
1013-
if not valid:
1021+
if require_converged and not valid:
10141022
raise RuntimeError('The algorithm may not have converged.')
10151023
if not runset._check_retcodes():
10161024
msg = 'Error during variational inference.\n{}'.format(

cmdstanpy/stanfit.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ def save_csvfiles(self, dir: str = None) -> None:
261261
with open(test_path, 'w'):
262262
pass
263263
os.remove(test_path) # cleanup
264-
except (IOError, OSError, PermissionError):
265-
raise Exception('cannot save to path: {}'.format(dir))
264+
except (IOError, OSError, PermissionError) as exc:
265+
raise Exception('cannot save to path: {}'.format(dir)) from exc
266266

267267
for i in range(self.chains):
268268
if not os.path.exists(self._csv_files[i]):
@@ -658,7 +658,12 @@ def summary(self, percentiles: List[int] = None) -> pd.DataFrame:
658658
do_command(cmd, logger=self.runset._logger)
659659
with open(tmp_csv_path, 'rb') as fd:
660660
summary_data = pd.read_csv(
661-
fd, delimiter=',', header=0, index_col=0, comment='#'
661+
fd,
662+
delimiter=',',
663+
header=0,
664+
index_col=0,
665+
comment='#',
666+
float_precision='high',
662667
)
663668
mask = [x == 'lp__' or not x.endswith('__') for x in summary_data.index]
664669
return summary_data[mask]
@@ -971,7 +976,11 @@ def _assemble_generated_quantities(self) -> None:
971976
drawset_list = []
972977
for chain in range(self.runset.chains):
973978
drawset_list.append(
974-
pd.read_csv(self.runset.csv_files[chain], comment='#')
979+
pd.read_csv(
980+
self.runset.csv_files[chain],
981+
comment='#',
982+
float_precision='high',
983+
)
975984
)
976985
self._generated_quantities = pd.concat(drawset_list).values
977986

cmdstanpy/utils.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ def rdump(path: str, data: Dict) -> None:
387387

388388
def rload(fname: str) -> dict:
389389
"""Parse data and parameter variable values from an R dump format file.
390-
This parser only supports the subset of R dump data as described
391-
in the "Dump Data Format" section of the CmdStan manual, i.e.,
392-
scalar, vector, matrix, and array data types.
390+
This parser only supports the subset of R dump data as described
391+
in the "Dump Data Format" section of the CmdStan manual, i.e.,
392+
scalar, vector, matrix, and array data types.
393393
"""
394394
data_dict = {}
395395
with open(fname, 'r') as fd:
@@ -420,8 +420,8 @@ def rload(fname: str) -> dict:
420420

421421
def parse_rdump_value(rhs: str) -> Union[int, float, np.array]:
422422
"""Process right hand side of Rdump variable assignment statement.
423-
Value is either scalar, vector, or multi-dim structure.
424-
Use regex to capture structure values, dimensions.
423+
Value is either scalar, vector, or multi-dim structure.
424+
Use regex to capture structure values, dimensions.
425425
"""
426426
pat = re.compile(
427427
r'structure\(\s*c\((?P<vals>[^)]*)\)'
@@ -444,8 +444,8 @@ def parse_rdump_value(rhs: str) -> Union[int, float, np.array]:
444444
val = float(rhs)
445445
else:
446446
val = int(rhs)
447-
except TypeError:
448-
raise ValueError('bad value in Rdump file: {}'.format(rhs))
447+
except TypeError as exc:
448+
raise ValueError('bad value in Rdump file: {}'.format(rhs)) from exc
449449
return val
450450

451451

@@ -553,27 +553,26 @@ def scan_variational_csv(path: str) -> Dict:
553553
lineno = scan_column_names(fd, dict, lineno)
554554
line = fd.readline().lstrip(' #\t').rstrip()
555555
lineno += 1
556-
if not line.startswith('Stepsize adaptation complete.'):
557-
raise ValueError(
558-
'line {}: expecting adaptation msg, found:\n\t "{}"'.format(
559-
lineno, line
560-
)
561-
)
562-
line = fd.readline().lstrip(' #\t\n')
563-
lineno += 1
564-
if not line.startswith('eta = 1'):
565-
raise ValueError(
566-
'line {}: expecting eta = 1, found:\n\t "{}"'.format(
567-
lineno, line
556+
if line.startswith('Stepsize adaptation complete.'):
557+
line = fd.readline().lstrip(' #\t\n')
558+
lineno += 1
559+
if not line.startswith('eta'):
560+
raise ValueError(
561+
'line {}: expecting eta, found:\n\t "{}"'.format(
562+
lineno, line
563+
)
568564
)
569-
)
570-
line = fd.readline().lstrip(' #\t\n')
571-
lineno += 1
565+
line = fd.readline().lstrip(' #\t\n')
566+
lineno += 1
572567
xs = line.split(',')
573568
variational_mean = [float(x) for x in xs]
574569
dict['variational_mean'] = variational_mean
575570
dict['variational_sample'] = pd.read_csv(
576-
path, comment='#', skiprows=lineno, header=None
571+
path,
572+
comment='#',
573+
skiprows=lineno,
574+
header=None,
575+
float_precision='high',
577576
)
578577
return dict
579578

@@ -691,10 +690,10 @@ def scan_metric(fd: TextIO, config_dict: Dict, lineno: int) -> int:
691690
)
692691
try:
693692
float(stepsize.strip())
694-
except ValueError:
693+
except ValueError as exc:
695694
raise ValueError(
696695
'line {}: invalid stepsize: {}'.format(lineno, stepsize)
697-
)
696+
) from exc
698697
line = fd.readline().strip()
699698
lineno += 1
700699
if not (

0 commit comments

Comments
 (0)