55import io
66import json
77import math
8- import os
98import re
109import warnings
11- from dataclasses import dataclass
12- from pathlib import Path
1310from typing import (
1411 Any ,
15- Callable ,
1612 Dict ,
1713 Iterator ,
1814 List ,
3127from cmdstanpy import _CMDSTAN_SAMPLING , _CMDSTAN_THIN , _CMDSTAN_WARMUP
3228
3329
34- @dataclass
35- class ParsingRule :
36- """Defines a rule for parsing a Stan CSV file. The parser transitions
37- between two states: either in or out of a comment section. Each section
38- is associated with one of these rules. On each line within a section,
39- the action is called. If an alternative action should be taken when
40- entering a section, the entry_action should be specified."""
41-
42- action : Callable [[bytes ], None ]
43- entry_action : Optional [Callable [[bytes ], None ]] = None
44-
45-
46- @dataclass
47- class StanCsvMCMC :
48- """Class containing the parsed output of a Stan CSV file sourced
49- from the `sample` inference method."""
50-
51- config : Dict [str , Union [int , float , str ]]
52- warmup_draws : Optional [npt .NDArray [np .float32 ]]
53- step_size : Optional [float ]
54- mass_matrix : Optional [npt .NDArray [np .float32 ]]
55- sampling_draws : npt .NDArray [np .float32 ]
56- timings : Dict [str , float ]
57-
58- @classmethod
59- def from_csv (
60- cls , path : Union [os .PathLike , Path , str ], is_fixed_param : bool = False
61- ) -> "StanCsvMCMC" :
62- config_lines : List [bytes ] = []
63- warmup_lines : List [bytes ] = []
64- adaptation_lines : List [bytes ] = []
65- sampling_lines : List [bytes ] = []
66- timing_lines : List [bytes ] = []
67-
68- def add_header (line : bytes ) -> None :
69- warmup_lines .append (line )
70- sampling_lines .append (line )
71-
72- rules : Tuple [ParsingRule , ...] = tuple ()
73- if is_fixed_param :
74- rules = (
75- ParsingRule (action = config_lines .append ),
76- ParsingRule (action = sampling_lines .append ),
77- ParsingRule (action = timing_lines .append ),
78- )
79- else :
80- rules = (
81- ParsingRule (action = config_lines .append ),
82- ParsingRule (
83- entry_action = add_header , action = warmup_lines .append
84- ),
85- ParsingRule (action = adaptation_lines .append ),
86- ParsingRule (action = sampling_lines .append ),
87- ParsingRule (action = timing_lines .append ),
88- )
89- with open (path , "rb" ) as f :
90- parse_general_stan_csv_from_lines (f , rules )
91-
92- sampling_draws = csv_bytes_list_to_numpy (sampling_lines )
93- config_dict : Dict [str , Union [str , int , float ]] = {}
94- scan_config (
95- io .StringIO ("" .join (ln .decode () for ln in config_lines )),
96- config_dict ,
97- 0 ,
98- )
99- if is_fixed_param :
100- warmup_draws , step_size , mass_matrix = None , None , None
101- else :
102- warmup_draws = csv_bytes_list_to_numpy (warmup_lines )
103- step_size , mass_matrix = parse_hmc_adaptation_lines (
104- adaptation_lines
105- )
106- return cls (
107- config_dict ,
108- warmup_draws ,
109- step_size ,
110- mass_matrix ,
111- sampling_draws ,
112- parse_timing_lines (timing_lines ),
113- )
114-
115-
116- def parse_general_stan_csv_from_lines (
30+ def parse_stan_csv_comments_and_draws (
11731 lines : Iterator [bytes ],
118- rules : Tuple [ParsingRule , ...],
119- start_in_comment : bool = True ,
120- ) -> None :
121- """Parses a generalized Stan CSV structure via provided rules.
122- The core idea is that Stan CSV files can be partitioned into coherent
123- sections based on the order of commented/non-commented lines in the file.
124- The rules define actions to be taken while within a given section and
125- transitioning between them. For example, in the MCMC Stan CSV files
126- an initial commented config section is followed by uncommented lines
127- that represent the warmup draws."""
128- current_rule_idx = 0
129- in_comment = start_in_comment
32+ ) -> Tuple [List [bytes ], List [bytes ]]:
33+ """Parses lines of a Stan CSV file into comment lines and draws lines, where
34+ a draws line is just a non-commented line.
35+
36+ Returns a (comment_lines, draws_lines) tuple.
37+ """
38+ comment_lines , draws_lines = [], []
13039
13140 for line in lines :
132- is_comment = line .startswith (b"#" )
133- if is_comment == in_comment :
134- rules [current_rule_idx ].action (line )
41+ if line .startswith (b"#" ): # is comment line
42+ comment_lines .append (line )
13543 else :
136- current_rule_idx += 1
137- if len (rules ) == current_rule_idx :
138- raise IndexError (
139- "Insufficient parsing rules to parse provided csv"
140- )
141- in_comment = is_comment
142- next_entry_action = rules [current_rule_idx ].entry_action
143- if next_entry_action is not None :
144- next_entry_action (line )
145- else : # If no entry_action defined, run normal action
146- rules [current_rule_idx ].action (line )
44+ draws_lines .append (line )
45+ return comment_lines , draws_lines
14746
14847
14948def csv_bytes_list_to_numpy (
@@ -185,19 +84,23 @@ def csv_bytes_list_to_numpy(
18584
18685
18786def parse_hmc_adaptation_lines (
188- adaptation_lines : List [bytes ],
87+ comment_lines : List [bytes ],
18988) -> Tuple [float , Optional [npt .NDArray [np .float32 ]]]:
190- """Extracts step size/mass matrix information from the adaptation
191- section of the Stan CSV . If unit metric is used, the mass matrix
192- field will be None, otherwise an appropriate numpy array.
89+ """Extracts step size/mass matrix information from the Stan CSV comment
90+ lines by parsing the adaptation section . If unit metric is used, the mass
91+ matrix field will be None, otherwise an appropriate numpy array.
19392
19493 Returns a (step_size, mass_matrix) tuple"""
19594 step_size , mass_matrix = None , None
196- lines_without_comments = (ln .lstrip (b"# " ) for ln in adaptation_lines )
95+
96+ cleaned_lines = (ln .lstrip (b"# " ) for ln in comment_lines )
19797 in_matrix_block = False
19898 matrix_lines = []
199- for line in lines_without_comments :
99+ for line in cleaned_lines :
200100 if in_matrix_block and line .strip ():
101+ # Stop when we get to timing block
102+ if line .startswith (b"Elapsed Time" ):
103+ break
201104 matrix_lines .append (line )
202105 elif line .startswith (b"Step size" ):
203106 _ , ss_str = line .split (b" = " )
@@ -216,14 +119,21 @@ def parse_hmc_adaptation_lines(
216119
217120
218121def parse_timing_lines (
219- timing_lines : List [bytes ],
122+ comment_lines : List [bytes ],
220123) -> Dict [str , float ]:
221124 """Parse the timing lines into a dictionary with key corresponding
222125 to the phase, e.g. Warm-up, Sampling, Total, and value the elapsed seconds
223126 """
224127 out : Dict [str , float ] = {}
225- lines_without_comments = (ln .lstrip (b"# " ) for ln in timing_lines )
226- for line in lines_without_comments :
128+
129+ cleaned_lines = (ln .lstrip (b"# " ) for ln in comment_lines )
130+ in_timing_block = False
131+ for line in cleaned_lines :
132+ if line .startswith (b"Elapsed Time" ) and not in_timing_block :
133+ in_timing_block = True
134+
135+ if not in_timing_block :
136+ continue
227137 match = re .findall (r"([\d\.]+) seconds \((.+)\)" , str (line ))
228138 if match :
229139 seconds = float (match [0 ][0 ])
0 commit comments