@@ -260,3 +260,48 @@ def test_parsing_adaptation_lines_no_free_params():
260260 ]
261261 _ , mass_matrix = stancsv .parse_hmc_adaptation_lines (lines )
262262 assert mass_matrix is None
263+
264+
265+ def test_csv_polars_and_numpy_equiv ():
266+ lines = [
267+ b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n " ,
268+ b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n " ,
269+ b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n " ,
270+ b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n " ,
271+ ]
272+ arr_out_polars = stancsv .csv_bytes_list_to_numpy (
273+ lines , includes_header = False
274+ )
275+ with mock .patch .dict ("sys.modules" , {"polars" : None }):
276+ arr_out_numpy = stancsv .csv_bytes_list_to_numpy (
277+ lines , includes_header = False
278+ )
279+ assert np .array_equiv (arr_out_polars , arr_out_numpy )
280+
281+
282+ def test_csv_polars_and_numpy_equiv_one_line ():
283+ lines = [
284+ b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n " ,
285+ ]
286+ arr_out_polars = stancsv .csv_bytes_list_to_numpy (
287+ lines , includes_header = False
288+ )
289+ with mock .patch .dict ("sys.modules" , {"polars" : None }):
290+ arr_out_numpy = stancsv .csv_bytes_list_to_numpy (
291+ lines , includes_header = False
292+ )
293+ assert np .array_equiv (arr_out_polars , arr_out_numpy )
294+
295+
296+ def test_csv_polars_and_numpy_equiv_one_element ():
297+ lines = [
298+ b"-6.76206\n " ,
299+ ]
300+ arr_out_polars = stancsv .csv_bytes_list_to_numpy (
301+ lines , includes_header = False
302+ )
303+ with mock .patch .dict ("sys.modules" , {"polars" : None }):
304+ arr_out_numpy = stancsv .csv_bytes_list_to_numpy (
305+ lines , includes_header = False
306+ )
307+ assert np .array_equiv (arr_out_polars , arr_out_numpy )
0 commit comments