Skip to content

Commit bba3bde

Browse files
committed
Add numpy/polars equiv testing
1 parent d1838f1 commit bba3bde

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

test/test_stancsv.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,48 @@ def test_parsing_adaptation_lines_no_free_params():
260260
]
261261
_, mass_matrix = stancsv.parse_hmc_adaptation_lines(lines)
262262
assert mass_matrix is None
263+
264+
265+
def test_csv_polars_and_numpy_equiv():
266+
lines = [
267+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
268+
b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n",
269+
b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n",
270+
b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n",
271+
]
272+
arr_out_polars = stancsv.csv_bytes_list_to_numpy(
273+
lines, includes_header=False
274+
)
275+
with mock.patch.dict("sys.modules", {"polars": None}):
276+
arr_out_numpy = stancsv.csv_bytes_list_to_numpy(
277+
lines, includes_header=False
278+
)
279+
assert np.array_equiv(arr_out_polars, arr_out_numpy)
280+
281+
282+
def test_csv_polars_and_numpy_equiv_one_line():
283+
lines = [
284+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
285+
]
286+
arr_out_polars = stancsv.csv_bytes_list_to_numpy(
287+
lines, includes_header=False
288+
)
289+
with mock.patch.dict("sys.modules", {"polars": None}):
290+
arr_out_numpy = stancsv.csv_bytes_list_to_numpy(
291+
lines, includes_header=False
292+
)
293+
assert np.array_equiv(arr_out_polars, arr_out_numpy)
294+
295+
296+
def test_csv_polars_and_numpy_equiv_one_element():
297+
lines = [
298+
b"-6.76206\n",
299+
]
300+
arr_out_polars = stancsv.csv_bytes_list_to_numpy(
301+
lines, includes_header=False
302+
)
303+
with mock.patch.dict("sys.modules", {"polars": None}):
304+
arr_out_numpy = stancsv.csv_bytes_list_to_numpy(
305+
lines, includes_header=False
306+
)
307+
assert np.array_equiv(arr_out_polars, arr_out_numpy)

0 commit comments

Comments
 (0)