Skip to content

Commit 83f9509

Browse files
committed
Support non-list iterable as input to DataclassWriter
reimplement and close dfurtado/dataclass-csv#61
1 parent d9f21d8 commit 83f9509

2 files changed

Lines changed: 18 additions & 14 deletions

File tree

dataclass_csv/dataclass_writer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import csv
22
import dataclasses
3-
from typing import Type, Dict, Any, List
3+
from typing import Type, Dict, Any, List, Iterable, Generic, TypeVar
44
from .header_mapper import HeaderMapper
55

66

7-
class DataclassWriter:
7+
8+
T = TypeVar("T")
9+
10+
11+
class DataclassWriter(Generic[T]):
812
def __init__(
913
self,
1014
f: Any,
11-
data: List[Any],
12-
klass: Type[object],
15+
data: Iterable[T],
16+
klass: Type[T],
1317
dialect: str = "excel",
1418
**fmtparams: Any,
1519
):
1620
if not f:
1721
raise ValueError("The f argument is required")
1822

19-
if not isinstance(data, list):
20-
raise ValueError("Invalid 'data' argument. It must be a list")
21-
2223
if not dataclasses.is_dataclass(klass):
2324
raise ValueError("Invalid 'klass' argument. It must be a dataclass")
2425

@@ -48,7 +49,6 @@ def write(self, skip_header: bool = False):
4849
self._fieldnames = self._apply_mapping()
4950

5051
self._writer.writerow(self._fieldnames)
51-
5252
for item in self._data:
5353
if not isinstance(item, self._cls):
5454
raise TypeError(

tests/test_dataclass_writer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,16 @@ def test_invalid_file_value(tmpdir_factory):
6161
with pytest.raises(ValueError):
6262
DataclassWriter(None, users, User)
6363

64-
65-
def test_with_data_not_a_list(tmpdir_factory):
64+
def test_with_iterable(tmpdir_factory):
6665
tempfile = tmpdir_factory.mktemp("data").join("user_001.csv")
67-
68-
users = User(name="test", age=40)
66+
users_dict = {"test": User(name="test", age=40)}
6967

7068
with tempfile.open("w") as f:
71-
with pytest.raises(ValueError):
72-
DataclassWriter(f, users, User)
69+
DataclassWriter(f, users_dict.values(), User).write()
70+
71+
with tempfile.open() as f:
72+
reader = DataclassReader(f, User)
73+
saved_users = list(reader)
74+
75+
assert len(saved_users) > 0
76+
assert saved_users[0].name == users_dict["test"].name

0 commit comments

Comments
 (0)