Skip to content

Commit 9e4340a

Browse files
committed
Refactor parsing to be function-based
1 parent 8cb1b7e commit 9e4340a

3 files changed

Lines changed: 59 additions & 192 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -441,30 +441,30 @@ def _assemble_draws(self) -> None:
441441
)
442442
self._step_size = np.empty(self.chains, dtype=float)
443443

444+
mass_matrix_per_chain = []
444445
for chain in range(self.chains):
445-
parsed_csv = stancsv.StanCsvMCMC.from_csv(
446-
self.runset.csv_files[chain],
447-
is_fixed_param=self._is_fixed_param,
446+
with open(self.runset.csv_files[chain], "rb") as f:
447+
comments, draws = stancsv.parse_stan_csv_comments_and_draws(f)
448+
449+
self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)
450+
451+
if not self._is_fixed_param:
452+
(
453+
self._step_size[chain],
454+
mass_matrix,
455+
) = stancsv.parse_hmc_adaptation_lines(comments)
456+
mass_matrix_per_chain.append(mass_matrix)
457+
458+
if mass_matrix_per_chain[0] is not None:
459+
mm_shape = mass_matrix_per_chain[0].shape
460+
if self.metric_type == "diag_e":
461+
mm_shape = mm_shape[1:]
462+
self._metric = np.empty(
463+
(self.chains, *mm_shape),
464+
dtype=np.float32,
448465
)
449-
self._step_size[chain] = parsed_csv.step_size
450-
if self._save_warmup and parsed_csv.warmup_draws is not None:
451-
self._draws[:, chain, :] = np.concatenate(
452-
[parsed_csv.warmup_draws, parsed_csv.sampling_draws]
453-
)
454-
else:
455-
self._draws[:, chain, :] = parsed_csv.sampling_draws
456-
457-
if parsed_csv.mass_matrix is not None:
458-
if chain == 0:
459-
mm_shape = parsed_csv.mass_matrix.shape
460-
if self.metric_type == "diag_e":
461-
mm_shape = mm_shape[1:]
462-
self._metric = np.empty(
463-
(self.chains, *mm_shape),
464-
dtype=np.float32,
465-
)
466-
467-
self._metric[chain] = parsed_csv.mass_matrix
466+
for chain in range(self.chains):
467+
self._metric[chain] = mass_matrix_per_chain[chain]
468468

469469
assert self._draws is not None
470470

cmdstanpy/utils/stancsv.py

Lines changed: 32 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55
import io
66
import json
77
import math
8-
import os
98
import re
109
import warnings
11-
from dataclasses import dataclass
12-
from pathlib import Path
1310
from typing import (
1411
Any,
15-
Callable,
1612
Dict,
1713
Iterator,
1814
List,
@@ -31,119 +27,22 @@
3127
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
3228

3329

34-
@dataclass
35-
class ParsingRule:
36-
"""Defines a rule for parsing a Stan CSV file. The parser transitions
37-
between two states: either in or out of a comment section. Each section
38-
is associated with one of these rules. On each line within a section,
39-
the action is called. If an alternative action should be taken when
40-
entering a section, the entry_action should be specified."""
41-
42-
action: Callable[[bytes], None]
43-
entry_action: Optional[Callable[[bytes], None]] = None
44-
45-
46-
@dataclass
47-
class StanCsvMCMC:
48-
"""Class containing the parsed output of a Stan CSV file sourced
49-
from the `sample` inference method."""
50-
51-
config: Dict[str, Union[int, float, str]]
52-
warmup_draws: Optional[npt.NDArray[np.float32]]
53-
step_size: Optional[float]
54-
mass_matrix: Optional[npt.NDArray[np.float32]]
55-
sampling_draws: npt.NDArray[np.float32]
56-
timings: Dict[str, float]
57-
58-
@classmethod
59-
def from_csv(
60-
cls, path: Union[os.PathLike, Path, str], is_fixed_param: bool = False
61-
) -> "StanCsvMCMC":
62-
config_lines: List[bytes] = []
63-
warmup_lines: List[bytes] = []
64-
adaptation_lines: List[bytes] = []
65-
sampling_lines: List[bytes] = []
66-
timing_lines: List[bytes] = []
67-
68-
def add_header(line: bytes) -> None:
69-
warmup_lines.append(line)
70-
sampling_lines.append(line)
71-
72-
rules: Tuple[ParsingRule, ...] = tuple()
73-
if is_fixed_param:
74-
rules = (
75-
ParsingRule(action=config_lines.append),
76-
ParsingRule(action=sampling_lines.append),
77-
ParsingRule(action=timing_lines.append),
78-
)
79-
else:
80-
rules = (
81-
ParsingRule(action=config_lines.append),
82-
ParsingRule(
83-
entry_action=add_header, action=warmup_lines.append
84-
),
85-
ParsingRule(action=adaptation_lines.append),
86-
ParsingRule(action=sampling_lines.append),
87-
ParsingRule(action=timing_lines.append),
88-
)
89-
with open(path, "rb") as f:
90-
parse_general_stan_csv_from_lines(f, rules)
91-
92-
sampling_draws = csv_bytes_list_to_numpy(sampling_lines)
93-
config_dict: Dict[str, Union[str, int, float]] = {}
94-
scan_config(
95-
io.StringIO("".join(ln.decode() for ln in config_lines)),
96-
config_dict,
97-
0,
98-
)
99-
if is_fixed_param:
100-
warmup_draws, step_size, mass_matrix = None, None, None
101-
else:
102-
warmup_draws = csv_bytes_list_to_numpy(warmup_lines)
103-
step_size, mass_matrix = parse_hmc_adaptation_lines(
104-
adaptation_lines
105-
)
106-
return cls(
107-
config_dict,
108-
warmup_draws,
109-
step_size,
110-
mass_matrix,
111-
sampling_draws,
112-
parse_timing_lines(timing_lines),
113-
)
114-
115-
116-
def parse_general_stan_csv_from_lines(
30+
def parse_stan_csv_comments_and_draws(
11731
lines: Iterator[bytes],
118-
rules: Tuple[ParsingRule, ...],
119-
start_in_comment: bool = True,
120-
) -> None:
121-
"""Parses a generalized Stan CSV structure via provided rules.
122-
The core idea is that Stan CSV files can be partitioned into coherent
123-
sections based on the order of commented/non-commented lines in the file.
124-
The rules define actions to be taken while within a given section and
125-
transitioning between them. For example, in the MCMC Stan CSV files
126-
an initial commented config section is followed by uncommented lines
127-
that represent the warmup draws."""
128-
current_rule_idx = 0
129-
in_comment = start_in_comment
32+
) -> Tuple[List[bytes], List[bytes]]:
33+
"""Parses lines of a Stan CSV file into comment lines and draws lines, where
34+
a draws line is just a non-commented line.
35+
36+
Returns a (comment_lines, draws_lines) tuple.
37+
"""
38+
comment_lines, draws_lines = [], []
13039

13140
for line in lines:
132-
is_comment = line.startswith(b"#")
133-
if is_comment == in_comment:
134-
rules[current_rule_idx].action(line)
41+
if line.startswith(b"#"): # is comment line
42+
comment_lines.append(line)
13543
else:
136-
current_rule_idx += 1
137-
if len(rules) == current_rule_idx:
138-
raise IndexError(
139-
"Insufficient parsing rules to parse provided csv"
140-
)
141-
in_comment = is_comment
142-
next_entry_action = rules[current_rule_idx].entry_action
143-
if next_entry_action is not None:
144-
next_entry_action(line)
145-
else: # If no entry_action defined, run normal action
146-
rules[current_rule_idx].action(line)
44+
draws_lines.append(line)
45+
return comment_lines, draws_lines
14746

14847

14948
def csv_bytes_list_to_numpy(
@@ -185,19 +84,23 @@ def csv_bytes_list_to_numpy(
18584

18685

18786
def parse_hmc_adaptation_lines(
188-
adaptation_lines: List[bytes],
87+
comment_lines: List[bytes],
18988
) -> Tuple[float, Optional[npt.NDArray[np.float32]]]:
190-
"""Extracts step size/mass matrix information from the adaptation
191-
section of the Stan CSV. If unit metric is used, the mass matrix
192-
field will be None, otherwise an appropriate numpy array.
89+
"""Extracts step size/mass matrix information from the Stan CSV comment
90+
lines by parsing the adaptation section. If unit metric is used, the mass
91+
matrix field will be None, otherwise an appropriate numpy array.
19392
19493
Returns a (step_size, mass_matrix) tuple"""
19594
step_size, mass_matrix = None, None
196-
lines_without_comments = (ln.lstrip(b"# ") for ln in adaptation_lines)
95+
96+
cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
19797
in_matrix_block = False
19898
matrix_lines = []
199-
for line in lines_without_comments:
99+
for line in cleaned_lines:
200100
if in_matrix_block and line.strip():
101+
# Stop when we get to timing block
102+
if line.startswith(b"Elapsed Time"):
103+
break
201104
matrix_lines.append(line)
202105
elif line.startswith(b"Step size"):
203106
_, ss_str = line.split(b" = ")
@@ -216,14 +119,21 @@ def parse_hmc_adaptation_lines(
216119

217120

218121
def parse_timing_lines(
219-
timing_lines: List[bytes],
122+
comment_lines: List[bytes],
220123
) -> Dict[str, float]:
221124
"""Parse the timing lines into a dictionary with key corresponding
222125
to the phase, e.g. Warm-up, Sampling, Total, and value the elapsed seconds
223126
"""
224127
out: Dict[str, float] = {}
225-
lines_without_comments = (ln.lstrip(b"# ") for ln in timing_lines)
226-
for line in lines_without_comments:
128+
129+
cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
130+
in_timing_block = False
131+
for line in cleaned_lines:
132+
if line.startswith(b"Elapsed Time") and not in_timing_block:
133+
in_timing_block = True
134+
135+
if not in_timing_block:
136+
continue
227137
match = re.findall(r"([\d\.]+) seconds \((.+)\)", str(line))
228138
if match:
229139
seconds = float(match[0][0])

test/test_stancsv.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -137,57 +137,14 @@ def test_csv_bytes_to_numpy_header_no_draws_no_polars():
137137
stancsv.csv_bytes_list_to_numpy(lines)
138138

139139

140-
def test_parsing_with_rules():
140+
def test_parse_comments_and_draws():
141141
lines: List[bytes] = [b"# 1\n", b"2\n", b"3\n", b"# 4\n"]
142-
comment_lines = []
143-
non_comment_lines = []
144-
rules = (
145-
stancsv.ParsingRule(action=comment_lines.append),
146-
stancsv.ParsingRule(action=non_comment_lines.append),
147-
stancsv.ParsingRule(action=comment_lines.append),
142+
comment_lines, draws_lines = stancsv.parse_stan_csv_comments_and_draws(
143+
iter(lines)
148144
)
149-
stancsv.parse_general_stan_csv_from_lines(iter(lines), rules)
150-
assert comment_lines == [b"# 1\n", b"# 4\n"]
151-
assert non_comment_lines == [b"2\n", b"3\n"]
152-
153-
154-
def test_parsing_with_rules_not_start_in_comment():
155-
lines: List[bytes] = [b"1\n", b"2\n", b"3\n", b"# 4\n"]
156-
comment_lines = []
157-
non_comment_lines = []
158-
rules = (
159-
stancsv.ParsingRule(action=non_comment_lines.append),
160-
stancsv.ParsingRule(action=comment_lines.append),
161-
)
162-
stancsv.parse_general_stan_csv_from_lines(
163-
iter(lines), rules, start_in_comment=False
164-
)
165-
assert comment_lines == [b"# 4\n"]
166-
assert non_comment_lines == [b"1\n", b"2\n", b"3\n"]
167-
168145

169-
def test_parsing_with_rules_entry_action():
170-
lines: List[bytes] = [b"# 1\n", b"2\n", b"# 4\n"]
171-
parsed, entry = [], []
172-
rules = (
173-
stancsv.ParsingRule(action=parsed.append),
174-
stancsv.ParsingRule(action=parsed.append, entry_action=entry.append),
175-
stancsv.ParsingRule(action=parsed.append),
176-
)
177-
stancsv.parse_general_stan_csv_from_lines(iter(lines), rules)
178-
assert parsed == [b"# 1\n", b"# 4\n"]
179-
assert entry == [b"2\n"]
180-
181-
182-
def test_parsing_insufficient_rules():
183-
lines: List[bytes] = [b"# 1\n", b"2\n", b"# 4\n"]
184-
parsed = []
185-
rules = (
186-
stancsv.ParsingRule(action=parsed.append),
187-
stancsv.ParsingRule(action=parsed.append),
188-
)
189-
with pytest.raises(IndexError):
190-
stancsv.parse_general_stan_csv_from_lines(iter(lines), rules)
146+
assert comment_lines == [b"# 1\n", b"# 4\n"]
147+
assert draws_lines == [b"2\n", b"3\n"]
191148

192149

193150
def test_parsing_timing_lines():

0 commit comments

Comments
 (0)