Skip to content

Commit a4caa20

Browse files
Chowdhury-Anik authored and skrawcz committed
Add data source sinks for Polars LazyFrame (#791)
This commit adds four data sink methods for Polars LazyFrame:

- sink_parquet: Write LazyFrame to Parquet format
- sink_csv: Write LazyFrame to CSV format
- sink_ipc: Write LazyFrame to IPC/Feather format
- sink_ndjson: Write LazyFrame to NDJSON format

These sinks allow users to write LazyFrames directly without needing to call .collect() first, improving performance for large datasets.

Fixes #791
1 parent f3ff012 commit a4caa20

1 file changed

Lines changed: 174 additions & 0 deletions

File tree

hamilton/plugins/polars_lazyframe_extensions.py

Lines changed: 174 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -297,6 +297,166 @@ def name(cls) -> str:
297297
return "feather"
298298

299299

300+
301+
# The sink writers only implement ``save_data``, so they need the saver base class.
# NOTE(review): hoist this into the module's import block if DataSaver is not
# already imported there.
from hamilton.io.data_adapters import DataSaver


@dataclasses.dataclass
class PolarsSinkParquetWriter(DataSaver):
    """
    Class specifically to handle writing parquet files with Polars LazyFrame.
    Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_parquet.html

    Derives from ``DataSaver`` rather than ``DataLoader``: the class provides
    ``save_data`` only and is registered through ``register_data_savers``.
    """

    path: Union[str, Path]
    # kwargs:
    compression: str = "zstd"
    compression_level: Optional[int] = None
    statistics: bool = False
    row_group_size: Optional[int] = None
    data_page_size: Optional[int] = None

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the dataframe types this writer can be applied to."""
        return [DATAFRAME_TYPE]

    def _get_writing_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments for ``sink_parquet``.

        :return: the configured options, leaving out every option that is None.
        """
        candidates = {
            "compression": self.compression,
            "compression_level": self.compression_level,
            "statistics": self.statistics,
            "row_group_size": self.row_group_size,
            "data_page_size": self.data_page_size,
        }
        # Unset (None) options are omitted so Polars applies its own defaults.
        return {key: value for key, value in candidates.items() if value is not None}

    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
        """Sink ``data`` to ``self.path`` as parquet.

        :param data: the LazyFrame to write.
        :return: the result of ``utils.get_file_metadata`` for the written path.
        """
        data.sink_parquet(self.path, **self._get_writing_kwargs())
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return the key this writer is registered under."""
        return "parquet"
341+
342+
343+
# The sink writers only implement ``save_data``, so they need the saver base class.
# NOTE(review): hoist this into the module's import block if DataSaver is not
# already imported there.
from hamilton.io.data_adapters import DataSaver


@dataclasses.dataclass
class PolarsSinkCSVWriter(DataSaver):
    """
    Class specifically to handle writing CSV files with Polars LazyFrame.
    Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_csv.html

    Derives from ``DataSaver`` rather than ``DataLoader``: the class provides
    ``save_data`` only and is registered through ``register_data_savers``.
    """

    path: Union[str, Path]
    # kwargs:
    include_bom: bool = False
    include_header: bool = True
    separator: str = ","
    line_terminator: str = "\n"
    quote_char: str = '"'
    batch_size: int = 1024
    datetime_format: Optional[str] = None
    date_format: Optional[str] = None
    time_format: Optional[str] = None
    float_precision: Optional[int] = None
    null_value: Optional[str] = None
    quote_style: Optional[str] = None

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the dataframe types this writer can be applied to."""
        return [DATAFRAME_TYPE]

    def _get_writing_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments for ``sink_csv``.

        :return: the configured options, leaving out every option that is None.
        """
        candidates = {
            "include_bom": self.include_bom,
            "include_header": self.include_header,
            "separator": self.separator,
            "line_terminator": self.line_terminator,
            "quote_char": self.quote_char,
            "batch_size": self.batch_size,
            "datetime_format": self.datetime_format,
            "date_format": self.date_format,
            "time_format": self.time_format,
            "float_precision": self.float_precision,
            "null_value": self.null_value,
            "quote_style": self.quote_style,
        }
        # Unset (None) options are omitted so Polars applies its own defaults.
        return {key: value for key, value in candidates.items() if value is not None}

    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
        """Sink ``data`` to ``self.path`` as CSV.

        :param data: the LazyFrame to write.
        :return: the result of ``utils.get_file_metadata`` for the written path.
        """
        data.sink_csv(self.path, **self._get_writing_kwargs())
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return the key this writer is registered under."""
        return "csv"
404+
405+
406+
# The sink writers only implement ``save_data``, so they need the saver base class.
# NOTE(review): hoist this into the module's import block if DataSaver is not
# already imported there.
from hamilton.io.data_adapters import DataSaver


@dataclasses.dataclass
class PolarsSinkIPCWriter(DataSaver):
    """
    Class specifically to handle writing IPC/Feather files with Polars LazyFrame.
    Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ipc.html

    Derives from ``DataSaver`` rather than ``DataLoader``: the class provides
    ``save_data`` only and is registered through ``register_data_savers``.
    """

    path: Union[str, Path]
    # kwargs:
    compression: Optional[str] = "zstd"

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the dataframe types this writer can be applied to."""
        return [DATAFRAME_TYPE]

    def _get_writing_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments for ``sink_ipc``.

        :return: ``{"compression": ...}`` when a compression is set, else an empty dict.
        """
        if self.compression is None:
            # No compression configured: pass nothing and let Polars decide.
            return {}
        return {"compression": self.compression}

    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
        """Sink ``data`` to ``self.path`` as IPC/Feather.

        :param data: the LazyFrame to write.
        :return: the result of ``utils.get_file_metadata`` for the written path.
        """
        data.sink_ipc(self.path, **self._get_writing_kwargs())
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return the key this writer is registered under."""
        return "ipc"
434+
435+
436+
# The sink writers only implement ``save_data``, so they need the saver base class.
# NOTE(review): hoist this into the module's import block if DataSaver is not
# already imported there.
from hamilton.io.data_adapters import DataSaver


@dataclasses.dataclass
class PolarsSinkNDJSONWriter(DataSaver):
    """
    Class specifically to handle writing NDJSON files with Polars LazyFrame.
    Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ndjson.html
    Note: Load support for NDJSON is not yet implemented.

    Derives from ``DataSaver`` rather than ``DataLoader``: the class provides
    ``save_data`` only and is registered through ``register_data_savers``.
    """

    path: Union[str, Path]

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the dataframe types this writer can be applied to."""
        return [DATAFRAME_TYPE]

    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
        """Sink ``data`` to ``self.path`` as NDJSON.

        :param data: the LazyFrame to write.
        :return: the result of ``utils.get_file_metadata`` for the written path.
        """
        data.sink_ndjson(self.path)
        return utils.get_file_metadata(self.path)

    @classmethod
    def name(cls) -> str:
        """Return the key this writer is registered under."""
        return "ndjson"
457+
458+
459+
300460
def register_data_loaders():
301461
"""Function to register the data loaders for this extension."""
302462
for loader in [
@@ -308,3 +468,17 @@ def register_data_loaders():
308468

309469

310470
register_data_loaders()
471+
472+
473+
def register_data_savers():
    """Register each LazyFrame sink writer of this extension with the adapter registry."""
    sink_writers = (
        PolarsSinkParquetWriter,
        PolarsSinkCSVWriter,
        PolarsSinkIPCWriter,
        PolarsSinkNDJSONWriter,
    )
    for writer_cls in sink_writers:
        registry.register_adapter(writer_cls)


# Registration happens as a side effect of importing this module.
register_data_savers()

0 commit comments

Comments
 (0)