Skip to content

Commit 8cb1b7e

Browse files
committed
Make polars an optional dependency
1 parent 75de3b5 commit 8cb1b7e

3 files changed

Lines changed: 101 additions & 14 deletions

File tree

cmdstanpy/utils/stancsv.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import os
99
import re
10+
import warnings
1011
from dataclasses import dataclass
1112
from pathlib import Path
1213
from typing import (
@@ -26,7 +27,6 @@
2627
import numpy as np
2728
import numpy.typing as npt
2829
import pandas as pd
29-
import polars as pl
3030

3131
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
3232

@@ -152,13 +152,34 @@ def csv_bytes_list_to_numpy(
152152
"""Efficiently converts a list of bytes representing whose concatenation
153153
represents a CSV file into a numpy array. Includes header specifies
154154
whether the bytes contains an initial header line."""
155-
out = (
156-
pl.read_csv(
157-
io.BytesIO(b"".join(csv_bytes_list)), has_header=includes_header
158-
)
159-
.to_numpy()
160-
.astype(np.float32)
161-
)
155+
try:
156+
import polars as pl
157+
158+
try:
159+
out = (
160+
pl.read_csv(
161+
io.BytesIO(b"".join(csv_bytes_list)),
162+
has_header=includes_header,
163+
)
164+
.to_numpy()
165+
.astype(np.float32)
166+
)
167+
if out.shape[0] == 0:
168+
raise ValueError("No data found to parse")
169+
except pl.exceptions.NoDataError as exc:
170+
raise ValueError("No data found to parse") from exc
171+
except ImportError as exc:
172+
with warnings.catch_warnings():
173+
warnings.filterwarnings("ignore")
174+
out = np.loadtxt(
175+
csv_bytes_list,
176+
skiprows=int(includes_header),
177+
delimiter=",",
178+
dtype=np.float32,
179+
)
180+
if out.shape == (0,):
181+
raise ValueError("No data found to parse") from exc
182+
162183
# Telling the type checker we know the type is correct
163184
return cast(npt.NDArray[np.float32], out)
164185

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ readme = "README.md"
99
license = { text = "BSD-3-Clause" }
1010
authors = [{ name = "Stan Dev Team" }]
1111
requires-python = ">=3.8"
12-
dependencies = ["pandas", "numpy>=1.21", "tqdm", "stanio>=0.4.0,<2.0.0", "polars>=1.8.2"]
12+
dependencies = ["pandas", "numpy>=1.21", "tqdm", "stanio>=0.4.0,<2.0.0"]
1313
dynamic = ["version"]
1414
classifiers = [
1515
"Programming Language :: Python :: 3",
@@ -40,7 +40,7 @@ packages = ["cmdstanpy", "cmdstanpy.stanfit", "cmdstanpy.utils"]
4040
"cmdstanpy" = ["py.typed"]
4141

4242
[project.optional-dependencies]
43-
all = ["xarray"]
43+
all = ["xarray", "polars>=1.8.2"]
4444
test = [
4545
"flake8",
4646
"pylint",

test/test_stancsv.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""testing stancsv parsing"""
22

33
from typing import List
4+
from unittest import mock
45

56
import numpy as np
6-
import polars as pl
77
import pytest
88

99
from cmdstanpy.utils import stancsv
@@ -30,6 +30,28 @@ def test_csv_bytes_to_numpy_no_header():
3030
assert arr_out[0].dtype == np.float32
3131

3232

33+
def test_csv_bytes_to_numpy_no_header_no_polars():
34+
lines = [
35+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
36+
b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n",
37+
b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n",
38+
b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n",
39+
]
40+
expected = np.array(
41+
[
42+
[-6.76206, 1, 0.787025, 1, 1, 0, 6.81411, 0.229458],
43+
[-6.81411, 0.983499, 0.787025, 1, 1, 0, 6.8147, 0.20649],
44+
[-6.85511, 0.994945, 0.787025, 2, 3, 0, 6.85536, 0.310589],
45+
[-6.85511, 0.812189, 0.787025, 1, 1, 0, 7.16517, 0.310589],
46+
],
47+
dtype=np.float32,
48+
)
49+
with mock.patch.dict("sys.modules", {"polars": None}):
50+
arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=False)
51+
assert np.array_equiv(arr_out, expected)
52+
assert arr_out[0].dtype == np.float32
53+
54+
3355
def test_csv_bytes_to_numpy_with_header():
3456
lines = [
3557
(
@@ -54,21 +76,65 @@ def test_csv_bytes_to_numpy_with_header():
5476
assert np.array_equiv(arr_out, expected)
5577

5678

79+
def test_csv_bytes_to_numpy_with_header_no_polars():
80+
lines = [
81+
(
82+
b"lp__,accept_stat__,stepsize__,treedepth__,"
83+
b"n_leapfrog__,divergent__,energy__,theta\n"
84+
),
85+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
86+
b"-6.81411,0.983499,0.787025,1,1,0,6.8147,0.20649\n",
87+
b"-6.85511,0.994945,0.787025,2,3,0,6.85536,0.310589\n",
88+
b"-6.85511,0.812189,0.787025,1,1,0,7.16517,0.310589\n",
89+
]
90+
expected = np.array(
91+
[
92+
[-6.76206, 1, 0.787025, 1, 1, 0, 6.81411, 0.229458],
93+
[-6.81411, 0.983499, 0.787025, 1, 1, 0, 6.8147, 0.20649],
94+
[-6.85511, 0.994945, 0.787025, 2, 3, 0, 6.85536, 0.310589],
95+
[-6.85511, 0.812189, 0.787025, 1, 1, 0, 7.16517, 0.310589],
96+
],
97+
dtype=np.float32,
98+
)
99+
with mock.patch.dict("sys.modules", {"polars": None}):
100+
arr_out = stancsv.csv_bytes_list_to_numpy(lines, includes_header=True)
101+
assert np.array_equiv(arr_out, expected)
102+
103+
57104
def test_csv_bytes_to_numpy_empty():
58105
lines = [b""]
59-
with pytest.raises(pl.exceptions.NoDataError):
106+
with pytest.raises(ValueError):
60107
stancsv.csv_bytes_list_to_numpy(lines)
61108

62109

110+
def test_csv_bytes_to_numpy_empty_no_polars():
111+
lines = [b""]
112+
with mock.patch.dict("sys.modules", {"polars": None}):
113+
with pytest.raises(ValueError):
114+
stancsv.csv_bytes_list_to_numpy(lines)
115+
116+
63117
def test_csv_bytes_to_numpy_header_no_draws():
64118
lines = [
65119
(
66120
b"lp__,accept_stat__,stepsize__,treedepth__,"
67121
b"n_leapfrog__,divergent__,energy__,theta\n"
68122
),
69123
]
70-
arr_out = stancsv.csv_bytes_list_to_numpy(lines)
71-
assert arr_out.shape == (0, 8)
124+
with pytest.raises(ValueError):
125+
stancsv.csv_bytes_list_to_numpy(lines)
126+
127+
128+
def test_csv_bytes_to_numpy_header_no_draws_no_polars():
129+
lines = [
130+
(
131+
b"lp__,accept_stat__,stepsize__,treedepth__,"
132+
b"n_leapfrog__,divergent__,energy__,theta\n"
133+
),
134+
]
135+
with mock.patch.dict("sys.modules", {"polars": None}):
136+
with pytest.raises(ValueError):
137+
stancsv.csv_bytes_list_to_numpy(lines)
72138

73139

74140
def test_parsing_with_rules():

0 commit comments

Comments
 (0)