def _format_and_write(
    df: pl.DataFrame, columns: list[pl.Expr], f: BufferedWriter
) -> None:
    """
    Concatenate ``columns`` into one string column and write it to ``f``.

    The concatenation is first attempted on a lazy frame collected with the
    Polars streaming engine. If that attempt raises, a warning (including
    the traceback) is logged and the same selection is evaluated eagerly.

    The result is written as CSV without a header and without quoting,
    using a single space as separator and an empty string for nulls.
    """

    def joined() -> pl.Expr:
        # Built anew for each attempt, so the streaming and the eager path
        # each receive their own expression object.
        return pl.concat_str(columns, ignore_nulls=True)

    try:
        rendered = df.lazy().select(joined()).collect(engine="streaming")
    except Exception:
        # Best-effort fallback: any streaming failure is reported and the
        # eager engine takes over; errors from the eager path propagate.
        logger.warning(
            "Polars streaming engine failed, falling back to eager evaluation. "
            "Please report this at https://github.com/PyPSA/linopy/issues",
            exc_info=True,
        )
        rendered = df.select(joined())

    rendered.write_csv(
        f,
        separator=" ",
        null_value="",
        quote_style="never",
        include_header=False,
    )
@@ -155,10 +183,7 @@ def objective_write_linear_terms( *signed_number(pl.col("coeffs")), *print_variable(pl.col("vars")), ] - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + _format_and_write(df, cols, f) def objective_write_quadratic_terms( @@ -171,10 +196,7 @@ def objective_write_quadratic_terms( *print_variable(pl.col("vars2")), ] f.write(b"+ [\n") - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + _format_and_write(df, cols, f) f.write(b"] / 2\n") @@ -254,11 +276,7 @@ def bounds_to_file( *signed_number(pl.col("upper")), ] - kwargs: Any = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + _format_and_write(df, columns, f) def binaries_to_file( @@ -296,11 +314,7 @@ def binaries_to_file( *print_variable(pl.col("labels")), ] - kwargs: Any = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + _format_and_write(df, columns, f) def integers_to_file( @@ -339,11 +353,7 @@ def integers_to_file( *print_variable(pl.col("labels")), ] - kwargs: Any = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + _format_and_write(df, columns, f) def sos_to_file( @@ -399,11 +409,7 @@ def sos_to_file( pl.col("var_weights"), ] - kwargs: Any = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + _format_and_write(df, columns, f) def 
constraints_to_file( @@ -487,11 +493,7 @@ def constraints_to_file( pl.when(pl.col("is_last_in_group")).then(pl.col("rhs").cast(pl.String)), ] - kwargs: Any = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + _format_and_write(df, columns, f) # in the future, we could use lazy dataframes when they support appending # tp existent files