Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Upcoming Version
* Reduced memory usage and faster file I/O operations when exporting models to LP format
* Improved the constraint equality check in `linopy.testing.assert_conequal` so that it can optionally be less strict
* Minor bugfix for multiplying variables with numpy type constants
* Harmonize dtypes before concatenation in LP file writing to avoid dtype mismatch errors. This error occurred when models were created and stored in netCDF format on Windows machines and then loaded and solved on Linux machines.

Version 0.5.6
--------------
Expand Down
118 changes: 118 additions & 0 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,124 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
return df


def harmonize_polars_dtypes(
    df1: pl.DataFrame, df2: pl.DataFrame
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Bring the shared columns of two polars DataFrames onto matching dtypes.

    Every column name present in both frames is inspected; when the two
    dtypes differ, a common dtype is looked up and whichever side does not
    already have it is cast. Columns that exist in only one frame are never
    touched.

    Args:
    ----
        df1 (pl.DataFrame): First DataFrame
        df2 (pl.DataFrame): Second DataFrame

    Returns:
    -------
        tuple[pl.DataFrame, pl.DataFrame]: The two frames, in the same order,
        with casts applied where a common dtype was found. A frame that needs
        no cast is returned as the very object that was passed in.
    """
    shared = set(df1.columns).intersection(df2.columns)
    if not shared:
        # Nothing overlaps, so there is nothing to reconcile.
        return df1, df2

    # Pending casts per frame: column name -> target dtype.
    left_casts: dict[str, Any] = {}
    right_casts: dict[str, Any] = {}

    for name in shared:
        left_dtype = df1[name].dtype
        right_dtype = df2[name].dtype
        if left_dtype != right_dtype:
            common = _get_common_polars_dtype(left_dtype, right_dtype)
            if common is None:
                continue
            # Only the side(s) that deviate from the common dtype get cast.
            for current, pending in (
                (left_dtype, left_casts),
                (right_dtype, right_casts),
            ):
                if current != common:
                    pending[name] = common

    # Skip the cast call entirely when a frame has nothing to change.
    if left_casts:
        df1 = df1.cast(left_casts)  # type: ignore
    if right_casts:
        df2 = df2.cast(right_casts)  # type: ignore

    return df1, df2


def _get_common_polars_dtype(dtype1: pl.DataType, dtype2: pl.DataType) -> Any:
    """
    Get the common dtype for two polars dtypes by upcasting.

    Resolution rules, applied in order:

    * equal dtypes -> ``None`` (nothing to do);
    * two integers of the same signedness -> the wider of the two;
    * any other pair of integers (mixed signedness, or an integer type not
      in the 8-64 bit tables below) -> ``pl.Int64``;
    * a float paired with another numeric dtype -> ``pl.Float64`` if either
      side is ``Float64`` or if an integer wider than 16 bits is involved
      (``Float32`` cannot represent such integers exactly), else
      ``pl.Float32``;
    * either side is a string dtype -> the string dtype;
    * anything else -> ``dtype1``.

    Args:
    ----
        dtype1: First dtype
        dtype2: Second dtype

    Returns:
    -------
        pl.DataType | None: The common dtype, or None if no upcasting is needed
    """
    # If dtypes are the same, no conversion needed
    if dtype1 == dtype2:
        return None

    pair = (dtype1, dtype2)

    if dtype1.is_numeric() and dtype2.is_numeric():
        # Each family is ordered from narrowest to widest.
        signed = [pl.Int8, pl.Int16, pl.Int32, pl.Int64]
        unsigned = [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64]

        if dtype1.is_integer() and dtype2.is_integer():
            for family in (signed, unsigned):
                ranks = [
                    next((i for i, t in enumerate(family) if dtype == t), -1)
                    for dtype in pair
                ]
                if min(ranks) >= 0:
                    # Same signedness: the wider type holds both losslessly.
                    return family[max(ranks)]

            # Mixed signed/unsigned (or an integer type outside the tables).
            # NOTE(review): Int64 cannot represent UInt64 values above
            # 2**63 - 1 - confirm such values never reach this path.
            return pl.Int64

        if dtype1.is_float() or dtype2.is_float():
            if dtype1 == pl.Float64 or dtype2 == pl.Float64:
                return pl.Float64

            # Float32 has a 24-bit significand, so only 8- and 16-bit
            # integers survive the cast unchanged; wider integers need
            # Float64 to avoid silently altering values.
            float32_safe_ints = (pl.Int8, pl.Int16, pl.UInt8, pl.UInt16)
            for dtype in pair:
                if dtype.is_integer() and not any(
                    dtype == t for t in float32_safe_ints
                ):
                    return pl.Float64
            return pl.Float32

    # For non-numeric types, prefer the string dtype as a safe common type.
    # ``pl.Utf8`` is an alias of ``pl.String``, so a single check covers both.
    if dtype1 == pl.Utf8 or dtype2 == pl.Utf8:
        return pl.Utf8

    # If we can't determine a common type, use the first dtype
    # This maintains backward compatibility
    return dtype1


def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
"""
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
Expand Down
4 changes: 4 additions & 0 deletions linopy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_dims_with_index_levels,
get_label_position,
group_terms_polars,
harmonize_polars_dtypes,
has_optimized_model,
infer_schema_polars,
is_constant,
Expand Down Expand Up @@ -631,6 +632,9 @@ def to_polars(self) -> pl.DataFrame:
short = filter_nulls_polars(short)
check_has_nulls_polars(short, name=f"{self.type} {self.name}")

# Harmonize dtypes before concatenation to avoid dtype mismatch errors
short, long = harmonize_polars_dtypes(short, long)

df = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"])
# delete subsequent non-null rhs (happens if all vars per label are -1)
is_non_null = df["rhs"].is_not_null()
Expand Down
154 changes: 154 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import pandas as pd
import polars as pl
import pytest
import xarray as xr
from test_linear_expression import m, u, x # noqa: F401
Expand All @@ -20,6 +21,7 @@
assign_multiindex_safe,
best_int,
get_dims_with_index_levels,
harmonize_polars_dtypes,
iterate_slices,
)
from linopy.testing import assert_linequal, assert_varequal
Expand Down Expand Up @@ -700,3 +702,155 @@ def test_align(x: Variable, u: Variable) -> None: # noqa: F811
assert expr_obs.shape == (1, 1) # _term dim
assert isinstance(expr_obs, LinearExpression)
assert_linequal(expr_obs, expr.loc[[1]])


def test_harmonize_polars_dtypes_same_dtypes() -> None:
    """Frames whose shared columns already agree on dtype keep their dtypes."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
    right = pl.DataFrame({"a": [7, 8, 9], "b": [10.0, 11.0, 12.0]})

    harmonized = harmonize_polars_dtypes(left, right)

    # Each output frame is compared against the input at the same position.
    for original, result in zip((left, right), harmonized):
        assert result.dtypes == original.dtypes


def test_harmonize_polars_dtypes_no_overlap() -> None:
    """Frames without any shared column name keep their dtypes."""
    left = pl.DataFrame({"a": [1, 2, 3]})
    right = pl.DataFrame({"b": [4, 5, 6]})

    harmonized = harmonize_polars_dtypes(left, right)

    # Each output frame is compared against the input at the same position.
    for original, result in zip((left, right), harmonized):
        assert result.dtypes == original.dtypes


def test_harmonize_polars_dtypes_int_upcast() -> None:
    """Int32 paired with Int64 ends up as Int64 on both sides."""
    narrow = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    wide = pl.DataFrame({"a": pl.Series([4, 5, 6], dtype=pl.Int64)})

    # The wider integer type is expected in both returned frames.
    for frame in harmonize_polars_dtypes(narrow, wide):
        assert frame["a"].dtype == pl.Int64


def test_harmonize_polars_dtypes_mixed_signed_unsigned() -> None:
    """Int32 paired with UInt32 ends up as signed Int64 on both sides."""
    signed_frame = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    unsigned_frame = pl.DataFrame({"a": pl.Series([4, 5, 6], dtype=pl.UInt32)})

    # Mixed signedness is expected to resolve to Int64 in both frames.
    for frame in harmonize_polars_dtypes(signed_frame, unsigned_frame):
        assert frame["a"].dtype == pl.Int64


def test_harmonize_polars_dtypes_int_to_float() -> None:
    """Int32 paired with Float64 ends up as Float64 on both sides."""
    int_frame = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    float_frame = pl.DataFrame({"a": pl.Series([4.0, 5.0, 6.0], dtype=pl.Float64)})

    # The float type is expected in both returned frames.
    for frame in harmonize_polars_dtypes(int_frame, float_frame):
        assert frame["a"].dtype == pl.Float64


def test_harmonize_polars_dtypes_float_upcast() -> None:
    """Float32 paired with Float64 ends up as Float64 on both sides."""
    single = pl.DataFrame({"a": pl.Series([1.0, 2.0, 3.0], dtype=pl.Float32)})
    double = pl.DataFrame({"a": pl.Series([4.0, 5.0, 6.0], dtype=pl.Float64)})

    # The wider float type is expected in both returned frames.
    for frame in harmonize_polars_dtypes(single, double):
        assert frame["a"].dtype == pl.Float64


def test_harmonize_polars_dtypes_multiple_columns() -> None:
    """Several shared columns with different mismatches are all harmonized."""
    left = pl.DataFrame(
        {
            "a": pl.Series([1, 2], dtype=pl.Int32),
            "b": pl.Series([3.0, 4.0], dtype=pl.Float32),
            "c": pl.Series([5, 6], dtype=pl.Int64),
        }
    )
    right = pl.DataFrame(
        {
            "a": pl.Series([7, 8], dtype=pl.Int64),
            "b": pl.Series([9.0, 10.0], dtype=pl.Float64),
            "c": pl.Series([11, 12], dtype=pl.Int64),
        }
    )

    # "a" is upcast to Int64, "b" to Float64, and "c" was already Int64 on
    # both sides so it stays as it is.
    expected = {"a": pl.Int64, "b": pl.Float64, "c": pl.Int64}

    for frame in harmonize_polars_dtypes(left, right):
        for column, dtype in expected.items():
            assert frame[column].dtype == dtype


def test_harmonize_polars_dtypes_partial_overlap() -> None:
    """Test harmonizing when only some columns overlap."""
    df1 = pl.DataFrame(
        {"a": pl.Series([1, 2], dtype=pl.Int32), "b": [3.0, 4.0], "d": [5, 6]}
    )
    df2 = pl.DataFrame(
        {"a": pl.Series([7, 8], dtype=pl.Int64), "b": [9.0, 10.0], "e": [11, 12]}
    )

    result1, result2 = harmonize_polars_dtypes(df1, df2)

    # Overlapping column with mismatched dtypes is upcast to the wider type
    assert result1["a"].dtype == pl.Int64
    assert result2["a"].dtype == pl.Int64

    # Overlapping column whose dtypes already agree is left as it was
    assert result1["b"].dtype == df1["b"].dtype
    assert result2["b"].dtype == df2["b"].dtype

    # Non-overlapping columns stay on their own side only ...
    assert "d" in result1.columns
    assert "d" not in result2.columns
    assert "e" not in result1.columns
    assert "e" in result2.columns

    # ... and keep the dtype they had before harmonization
    assert result1["d"].dtype == df1["d"].dtype
    assert result2["e"].dtype == df2["e"].dtype


def test_harmonize_polars_dtypes_concat_compatibility() -> None:
    """Harmonized frames can be fed straight into a diagonal concat."""
    # Mirrors the real call site in to_polars: one frame carries Int32
    # labels/vars, the other Int64.
    int32_frame = pl.DataFrame(
        {
            "labels": pl.Series([0, 1], dtype=pl.Int32),
            "coeffs": [1.0, 2.0],
            "vars": pl.Series([10, 20], dtype=pl.Int32),
        }
    )
    int64_frame = pl.DataFrame(
        {
            "labels": pl.Series([2, 3], dtype=pl.Int64),
            "coeffs": [3.0, 4.0],
            "vars": pl.Series([30, 40], dtype=pl.Int64),
        }
    )

    harmonized = harmonize_polars_dtypes(int32_frame, int64_frame)

    # The concat itself is part of the check: it must not raise.
    stacked = pl.concat(list(harmonized), how="diagonal")

    assert len(stacked) == 4
    for column in ("labels", "vars"):
        assert stacked[column].dtype == pl.Int64
Loading