def harmonize_polars_dtypes(
    df1: pl.DataFrame, df2: pl.DataFrame
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Harmonize dtypes of overlapping columns between two polars DataFrames.

    Columns that appear in both frames but with different dtypes are cast to
    a common dtype, so that a subsequent ``pl.concat`` does not fail with a
    dtype mismatch. Such mismatches were reported for models that were
    created and stored in netcdf format on Windows machines and then loaded
    and solved on Linux machines.

    Columns that exist in only one of the frames, and shared columns whose
    dtypes already agree, are left untouched. Neither input is modified; the
    inputs themselves are returned when nothing needs to be cast.

    Args:
    ----
        df1 (pl.DataFrame): First DataFrame
        df2 (pl.DataFrame): Second DataFrame

    Returns:
    -------
        tuple[pl.DataFrame, pl.DataFrame]: Both DataFrames with harmonized dtypes
    """
    # Read dtypes from the schemas instead of materialising a Series per column
    schema1, schema2 = df1.schema, df2.schema

    # Column name -> dtype to cast to, kept separately for each frame
    cast_map1: dict[str, Any] = {}
    cast_map2: dict[str, Any] = {}

    # Walk the columns in df1's order so the cast maps are built deterministically
    for col, dtype1 in schema1.items():
        if col not in schema2:
            continue

        dtype2 = schema2[col]
        if dtype1 == dtype2:
            continue

        target_dtype = _get_common_polars_dtype(dtype1, dtype2)
        if target_dtype is None:
            continue

        if dtype1 != target_dtype:
            cast_map1[col] = target_dtype
        if dtype2 != target_dtype:
            cast_map2[col] = target_dtype

    # Only touch a frame when at least one of its columns has to change
    if cast_map1:
        df1 = df1.cast(cast_map1)  # type: ignore
    if cast_map2:
        df2 = df2.cast(cast_map2)  # type: ignore

    return df1, df2


def _get_common_polars_dtype(dtype1: pl.DataType, dtype2: pl.DataType) -> Any:
    """
    Get the common dtype for two polars dtypes by upcasting.

    Resolution rules, in order:

    * equal dtypes: ``None`` (nothing to do)
    * one side is ``Null`` (all-null column): the other side's dtype
    * two integers of the same signedness: the wider of the two
    * two integers of mixed signedness, or an integer dtype outside the
      8 to 64 bit families: ``Int64``
    * anything numeric paired with ``Float64``: ``Float64``
    * a 32 or 64 bit integer paired with ``Float32``: ``Float64``, because
      ``Float32`` cannot represent those integers exactly
    * remaining float combinations: ``Float32``
    * one side is a string dtype: the string dtype
    * otherwise: ``dtype1`` (best effort)

    Args:
    ----
        dtype1: First dtype
        dtype2: Second dtype

    Returns:
    -------
        The dtype both columns should be cast to, or None if the dtypes are
        already equal
    """
    # If dtypes are the same, no conversion needed
    if dtype1 == dtype2:
        return None

    # An all-null column carries no type information; adopt the typed side
    # rather than casting real data to Null.
    if dtype1 == pl.Null:
        return dtype2
    if dtype2 == pl.Null:
        return dtype1

    if dtype1.is_numeric() and dtype2.is_numeric():
        if dtype1.is_integer() and dtype2.is_integer():
            # Type hierarchies, ordered from narrow to wide
            signed = [pl.Int8, pl.Int16, pl.Int32, pl.Int64]
            unsigned = [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64]

            for family in (signed, unsigned):
                ranks = [
                    next((i for i, t in enumerate(family) if dtype == t), -1)
                    for dtype in (dtype1, dtype2)
                ]
                if min(ranks) >= 0:
                    # Same signedness: the wider type holds both
                    return family[max(ranks)]

            # Mixed signed/unsigned (or an unknown integer dtype): use Int64
            return pl.Int64

        if dtype1.is_float() or dtype2.is_float():
            if dtype1 == pl.Float64 or dtype2 == pl.Float64:
                return pl.Float64

            other = dtype2 if dtype1.is_float() else dtype1
            if other.is_integer():
                # Float32 has a 24 bit significand: only 8 and 16 bit
                # integers survive the conversion unchanged.
                small_ints = (pl.Int8, pl.Int16, pl.UInt8, pl.UInt16)
                if not any(other == t for t in small_ints):
                    return pl.Float64

            return pl.Float32

    # String is a safe common type as soon as one side is a string.
    # pl.Utf8 is the name available across polars versions (newer releases
    # expose the same dtype as pl.String).
    if dtype1 == pl.Utf8 or dtype2 == pl.Utf8:
        return pl.Utf8

    # No rule applies: fall back to the first dtype. Note that this makes the
    # caller cast the second column to dtype1, which is best effort only.
    return dtype1
# Tests for `harmonize_polars_dtypes` from linopy/common.py. In
# linopy/constraints.py, `to_polars` calls it on its `short` and `long` frames
# right before `pl.concat([short, long], how="diagonal")`.


def test_harmonize_polars_dtypes_same_dtypes() -> None:
    """Frames whose shared columns already agree keep their dtypes."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
    right = pl.DataFrame({"a": [7, 8, 9], "b": [10.0, 11.0, 12.0]})

    out_left, out_right = harmonize_polars_dtypes(left, right)

    assert out_left.dtypes == left.dtypes
    assert out_right.dtypes == right.dtypes


def test_harmonize_polars_dtypes_no_overlap() -> None:
    """Frames without a shared column keep their dtypes."""
    left = pl.DataFrame({"a": [1, 2, 3]})
    right = pl.DataFrame({"b": [4, 5, 6]})

    out_left, out_right = harmonize_polars_dtypes(left, right)

    assert out_left.dtypes == left.dtypes
    assert out_right.dtypes == right.dtypes


def test_harmonize_polars_dtypes_int_upcast() -> None:
    """Int32 paired with Int64 becomes Int64 on both sides."""
    narrow = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    wide = pl.DataFrame({"a": pl.Series([4, 5, 6], dtype=pl.Int64)})

    for frame in harmonize_polars_dtypes(narrow, wide):
        assert frame["a"].dtype == pl.Int64


def test_harmonize_polars_dtypes_mixed_signed_unsigned() -> None:
    """Int32 paired with UInt32 is resolved to signed Int64."""
    signed = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    unsigned = pl.DataFrame({"a": pl.Series([4, 5, 6], dtype=pl.UInt32)})

    for frame in harmonize_polars_dtypes(signed, unsigned):
        assert frame["a"].dtype == pl.Int64


def test_harmonize_polars_dtypes_int_to_float() -> None:
    """Int32 paired with Float64 becomes Float64 on both sides."""
    ints = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    floats = pl.DataFrame({"a": pl.Series([4.0, 5.0, 6.0], dtype=pl.Float64)})

    for frame in harmonize_polars_dtypes(ints, floats):
        assert frame["a"].dtype == pl.Float64


def test_harmonize_polars_dtypes_float_upcast() -> None:
    """Float32 paired with Float64 becomes Float64 on both sides."""
    single = pl.DataFrame({"a": pl.Series([1.0, 2.0, 3.0], dtype=pl.Float32)})
    double = pl.DataFrame({"a": pl.Series([4.0, 5.0, 6.0], dtype=pl.Float64)})

    for frame in harmonize_polars_dtypes(single, double):
        assert frame["a"].dtype == pl.Float64


def test_harmonize_polars_dtypes_multiple_columns() -> None:
    """Several mismatching columns are harmonized in a single call."""
    left = pl.DataFrame(
        {
            "a": pl.Series([1, 2], dtype=pl.Int32),
            "b": pl.Series([3.0, 4.0], dtype=pl.Float32),
            "c": pl.Series([5, 6], dtype=pl.Int64),
        }
    )
    right = pl.DataFrame(
        {
            "a": pl.Series([7, 8], dtype=pl.Int64),
            "b": pl.Series([9.0, 10.0], dtype=pl.Float64),
            "c": pl.Series([11, 12], dtype=pl.Int64),
        }
    )

    # "a" and "b" are upcast; "c" already agrees and stays Int64
    expected = {"a": pl.Int64, "b": pl.Float64, "c": pl.Int64}

    for frame in harmonize_polars_dtypes(left, right):
        for column, dtype in expected.items():
            assert frame[column].dtype == dtype


def test_harmonize_polars_dtypes_partial_overlap() -> None:
    """Only shared columns are harmonized; the others stay where they were."""
    left = pl.DataFrame(
        {"a": pl.Series([1, 2], dtype=pl.Int32), "b": [3.0, 4.0], "d": [5, 6]}
    )
    right = pl.DataFrame(
        {"a": pl.Series([7, 8], dtype=pl.Int64), "b": [9.0, 10.0], "e": [11, 12]}
    )

    out_left, out_right = harmonize_polars_dtypes(left, right)

    # Shared column with a mismatch is upcast on both sides
    assert out_left["a"].dtype == pl.Int64
    assert out_right["a"].dtype == pl.Int64

    # Columns unique to one frame neither vanish nor leak into the other
    assert "d" in out_left.columns
    assert "d" not in out_right.columns
    assert "e" not in out_left.columns
    assert "e" in out_right.columns


def test_harmonize_polars_dtypes_concat_compatibility() -> None:
    """Harmonized frames can be concatenated diagonally."""
    # Mirrors the use in `to_polars`: same columns, Int32 on one side and
    # Int64 on the other.
    first = pl.DataFrame(
        {
            "labels": pl.Series([0, 1], dtype=pl.Int32),
            "coeffs": [1.0, 2.0],
            "vars": pl.Series([10, 20], dtype=pl.Int32),
        }
    )
    second = pl.DataFrame(
        {
            "labels": pl.Series([2, 3], dtype=pl.Int64),
            "coeffs": [3.0, 4.0],
            "vars": pl.Series([30, 40], dtype=pl.Int64),
        }
    )

    harmonized = list(harmonize_polars_dtypes(first, second))

    # Must not raise
    stacked = pl.concat(harmonized, how="diagonal")

    assert len(stacked) == 4
    assert stacked["labels"].dtype == pl.Int64
    assert stacked["vars"].dtype == pl.Int64