Skip to content

Commit a38a271

Browse files
committed
Adds unit tests and fixes inheritance for lazy writers
1 parent ef61e7c commit a38a271

File tree

2 files changed

+100
-22
lines changed

2 files changed

+100
-22
lines changed

hamilton/plugins/polars_lazyframe_extensions.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
from hamilton import registry
5757
from hamilton.io import utils
58-
from hamilton.io.data_adapters import DataLoader
58+
from hamilton.io.data_adapters import DataLoader, DataSaver
5959

6060
DATAFRAME_TYPE = pl.LazyFrame
6161
COLUMN_TYPE = pl.Expr
@@ -297,25 +297,25 @@ def name(cls) -> str:
297297
return "feather"
298298

299299

300-
301300
@dataclasses.dataclass
302-
class PolarsSinkParquetWriter(DataLoader):
301+
class PolarsSinkParquetWriter(DataSaver):
303302
"""
304303
Class specifically to handle writing parquet files with Polars LazyFrame.
305304
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_parquet.html
306305
"""
306+
307307
path: Union[str, Path]
308308
# kwargs:
309309
compression: str = "zstd"
310310
compression_level: Optional[int] = None
311311
statistics: bool = False
312312
row_group_size: Optional[int] = None
313313
data_page_size: Optional[int] = None
314-
314+
315315
@classmethod
316316
def applicable_types(cls) -> Collection[Type]:
317317
return [DATAFRAME_TYPE]
318-
318+
319319
def _get_writing_kwargs(self):
320320
kwargs = {}
321321
if self.compression is not None:
@@ -329,23 +329,24 @@ def _get_writing_kwargs(self):
329329
if self.data_page_size is not None:
330330
kwargs["data_page_size"] = self.data_page_size
331331
return kwargs
332-
332+
333333
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
334334
data.sink_parquet(self.path, **self._get_writing_kwargs())
335335
metadata = utils.get_file_metadata(self.path)
336336
return metadata
337-
337+
338338
@classmethod
339339
def name(cls) -> str:
340340
return "parquet"
341341

342342

343343
@dataclasses.dataclass
344-
class PolarsSinkCSVWriter(DataLoader):
344+
class PolarsSinkCSVWriter(DataSaver):
345345
"""
346346
Class specifically to handle writing CSV files with Polars LazyFrame.
347347
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_csv.html
348348
"""
349+
349350
path: Union[str, Path]
350351
# kwargs:
351352
include_bom: bool = False
@@ -360,11 +361,11 @@ class PolarsSinkCSVWriter(DataLoader):
360361
float_precision: Optional[int] = None
361362
null_value: Optional[str] = None
362363
quote_style: Optional[str] = None
363-
364+
364365
@classmethod
365366
def applicable_types(cls) -> Collection[Type]:
366367
return [DATAFRAME_TYPE]
367-
368+
368369
def _get_writing_kwargs(self):
369370
kwargs = {}
370371
if self.include_bom is not None:
@@ -392,71 +393,72 @@ def _get_writing_kwargs(self):
392393
if self.quote_style is not None:
393394
kwargs["quote_style"] = self.quote_style
394395
return kwargs
395-
396+
396397
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
397398
data.sink_csv(self.path, **self._get_writing_kwargs())
398399
metadata = utils.get_file_metadata(self.path)
399400
return metadata
400-
401+
401402
@classmethod
402403
def name(cls) -> str:
403404
return "csv"
404405

405406

406407
@dataclasses.dataclass
407-
class PolarsSinkIPCWriter(DataLoader):
408+
class PolarsSinkIPCWriter(DataSaver):
408409
"""
409410
Class specifically to handle writing IPC/Feather files with Polars LazyFrame.
410411
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ipc.html
411412
"""
413+
412414
path: Union[str, Path]
413415
# kwargs:
414416
compression: Optional[str] = "zstd"
415-
417+
416418
@classmethod
417419
def applicable_types(cls) -> Collection[Type]:
418420
return [DATAFRAME_TYPE]
419-
421+
420422
def _get_writing_kwargs(self):
421423
kwargs = {}
422424
if self.compression is not None:
423425
kwargs["compression"] = self.compression
424426
return kwargs
425-
427+
426428
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
427429
data.sink_ipc(self.path, **self._get_writing_kwargs())
428430
metadata = utils.get_file_metadata(self.path)
429431
return metadata
430-
432+
431433
@classmethod
432434
def name(cls) -> str:
433435
return "ipc"
434436

435437

436438
@dataclasses.dataclass
437-
class PolarsSinkNDJSONWriter(DataLoader):
439+
class PolarsSinkNDJSONWriter(DataSaver):
438440
"""
439441
Class specifically to handle writing NDJSON files with Polars LazyFrame.
440442
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ndjson.html
441443
Note: Load support for NDJSON is not yet implemented.
442444
"""
445+
443446
path: Union[str, Path]
444-
447+
445448
@classmethod
446449
def applicable_types(cls) -> Collection[Type]:
447450
return [DATAFRAME_TYPE]
448-
451+
449452
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
450453
data.sink_ndjson(self.path)
451454
metadata = utils.get_file_metadata(self.path)
452455
return metadata
453-
456+
454457
@classmethod
455458
def name(cls) -> str:
456459
return "ndjson"
457460

458461

459-
460462
def register_data_loaders():
461463
"""Function to register the data loaders for this extension."""
462464
for loader in [

tests/plugins/test_polars_lazyframe_extensions.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
PolarsScanCSVReader,
2828
PolarsScanFeatherReader,
2929
PolarsScanParquetReader,
30+
PolarsSinkCSVWriter,
31+
PolarsSinkIPCWriter,
32+
PolarsSinkNDJSONWriter,
33+
PolarsSinkParquetWriter,
3034
)
3135
from hamilton.plugins.polars_post_1_0_0_extensions import (
3236
PolarsAvroReader,
@@ -193,3 +197,75 @@ def test_polars_spreadsheet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
193197
assert write_kwargs["include_header"] is True
194198
assert "raise_if_empty" in read_kwargs
195199
assert read_kwargs["raise_if_empty"] is True
200+
201+
202+
def test_polars_sink_parquet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
203+
file = tmp_path / "test_sink.parquet"
204+
205+
writer = PolarsSinkParquetWriter(path=file)
206+
kwargs = writer._get_writing_kwargs()
207+
metadata = writer.save_data(df)
208+
209+
# Read back the data to verify it was written correctly
210+
reader = PolarsScanParquetReader(file=file)
211+
df2, _ = reader.load_data(pl.LazyFrame)
212+
213+
assert PolarsSinkParquetWriter.applicable_types() == [pl.LazyFrame]
214+
assert kwargs["compression"] == "zstd"
215+
assert kwargs["statistics"] is False
216+
assert file.exists()
217+
assert metadata["file_metadata"]["path"] == str(file)
218+
assert_frame_equal(df.collect(), df2.collect())
219+
220+
221+
def test_polars_sink_csv(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
222+
file = tmp_path / "test_sink.csv"
223+
224+
writer = PolarsSinkCSVWriter(path=file)
225+
kwargs = writer._get_writing_kwargs()
226+
metadata = writer.save_data(df)
227+
228+
# Read back the data to verify it was written correctly
229+
reader = PolarsScanCSVReader(file=file)
230+
df2, _ = reader.load_data(pl.LazyFrame)
231+
232+
assert PolarsSinkCSVWriter.applicable_types() == [pl.LazyFrame]
233+
assert kwargs["separator"] == ","
234+
assert kwargs["include_header"] is True
235+
assert kwargs["batch_size"] == 1024
236+
assert file.exists()
237+
assert metadata["file_metadata"]["path"] == str(file)
238+
assert_frame_equal(df.collect(), df2.collect())
239+
240+
241+
def test_polars_sink_ipc(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
242+
file = tmp_path / "test_sink.ipc"
243+
244+
writer = PolarsSinkIPCWriter(path=file)
245+
kwargs = writer._get_writing_kwargs()
246+
metadata = writer.save_data(df)
247+
248+
# Read back the data to verify it was written correctly
249+
reader = PolarsScanFeatherReader(source=file)
250+
df2, _ = reader.load_data(pl.LazyFrame)
251+
252+
assert PolarsSinkIPCWriter.applicable_types() == [pl.LazyFrame]
253+
assert kwargs["compression"] == "zstd"
254+
assert file.exists()
255+
assert metadata["file_metadata"]["path"] == str(file)
256+
assert_frame_equal(df.collect(), df2.collect())
257+
258+
259+
def test_polars_sink_ndjson(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
260+
file = tmp_path / "test_sink.ndjson"
261+
262+
writer = PolarsSinkNDJSONWriter(path=file)
263+
metadata = writer.save_data(df)
264+
265+
# Read back the data to verify it was written correctly
266+
df2 = pl.read_ndjson(file)
267+
268+
assert PolarsSinkNDJSONWriter.applicable_types() == [pl.LazyFrame]
269+
assert file.exists()
270+
assert metadata["file_metadata"]["path"] == str(file)
271+
assert_frame_equal(df.collect(), df2)

0 commit comments

Comments
 (0)