Skip to content

Commit 58879fc

Browse files
committed
fix typing errors
1 parent 19125ac commit 58879fc

12 files changed

Lines changed: 133 additions & 103 deletions

File tree

linopy/common.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020
import polars as pl
2121
from numpy import arange, signedinteger
22+
from polars.datatypes import DataTypeClass
2223
from xarray import DataArray, Dataset, apply_ufunc, broadcast
2324
from xarray import align as xr_align
2425
from xarray.core import dtypes, indexing
@@ -327,7 +328,7 @@ def check_has_nulls(df: pd.DataFrame, name: str) -> None:
327328
raise ValueError(f"Fields {name} contains nan's in field(s) {fields}")
328329

329330

330-
def infer_schema_polars(ds: Dataset) -> dict[Hashable, pl.DataType]:
331+
def infer_schema_polars(ds: Dataset) -> dict[str, DataTypeClass]:
331332
"""
332333
Infer the polars data schema from a xarray dataset.
333334
@@ -339,21 +340,22 @@ def infer_schema_polars(ds: Dataset) -> dict[Hashable, pl.DataType]:
339340
-------
340341
dict: A dictionary mapping column names to their corresponding Polars data types.
341342
"""
342-
schema = {}
343+
schema: dict[str, DataTypeClass] = {}
343344
np_major_version = int(np.__version__.split(".")[0])
344345
use_int32 = os.name == "nt" and np_major_version < 2
345346
for name, array in ds.items():
347+
name = str(name)
346348
if np.issubdtype(array.dtype, np.integer):
347349
schema[name] = pl.Int32 if use_int32 else pl.Int64
348350
elif np.issubdtype(array.dtype, np.floating):
349-
schema[name] = pl.Float64 # type: ignore
351+
schema[name] = pl.Float64
350352
elif np.issubdtype(array.dtype, np.bool_):
351-
schema[name] = pl.Boolean # type: ignore
353+
schema[name] = pl.Boolean
352354
elif np.issubdtype(array.dtype, np.object_):
353-
schema[name] = pl.Object # type: ignore
355+
schema[name] = pl.Object
354356
else:
355-
schema[name] = pl.Utf8 # type: ignore
356-
return schema # type: ignore
357+
schema[name] = pl.Utf8
358+
return schema
357359

358360

359361
def to_polars(ds: Dataset, **kwargs: Any) -> pl.DataFrame:
@@ -429,7 +431,7 @@ def filter_nulls_polars(df: pl.DataFrame) -> pl.DataFrame:
429431
if "labels" in df.columns:
430432
cond.append(pl.col("labels").ne(-1))
431433

432-
cond = reduce(operator.and_, cond) # type: ignore
434+
cond = reduce(operator.and_, cond) # type: ignore[arg-type]
433435
return df.filter(cond)
434436

435437

@@ -655,7 +657,7 @@ def iterate_slices(
655657
start = i * chunk_size
656658
end = min(start + chunk_size, size_of_leading_dim)
657659
slice_dict = {leading_dim: slice(start, end)}
658-
yield ds.isel(slice_dict)
660+
yield ds.isel(slice_dict) # type: ignore[attr-defined]
659661

660662

661663
def _remap(array: np.ndarray, mapping: np.ndarray) -> np.ndarray:
@@ -1367,7 +1369,7 @@ def __getitem__(
13671369
# expand the indexer so we can handle Ellipsis
13681370
labels = indexing.expanded_indexer(key, self.object.ndim)
13691371
key = dict(zip(self.object.dims, labels))
1370-
return self.object.sel(key)
1372+
return self.object.sel(key) # type: ignore[attr-defined]
13711373

13721374

13731375
class EmptyDeprecationWrapper:
@@ -1439,9 +1441,9 @@ def coords_from_dataset(ds: Dataset, coord_dims: list[str]) -> list[pd.Index]:
14391441
if f"_coord_{d}_codes" in ds:
14401442
codes_2d = ds[f"_coord_{d}_codes"].values.T
14411443
level_names = [
1442-
k[len(f"_coord_{d}_level_") :]
1444+
str(k)[len(f"_coord_{d}_level_") :]
14431445
for k in ds
1444-
if k.startswith(f"_coord_{d}_level_")
1446+
if str(k).startswith(f"_coord_{d}_level_")
14451447
]
14461448
arrays = [
14471449
ds[f"_coord_{d}_level_{ln}"].values[codes_2d[i]]

linopy/constants.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,11 @@ def process(cls, status: str, termination_condition: str) -> "Status":
214214

215215
@classmethod
216216
def from_termination_condition(
217-
cls, termination_condition: Union["TerminationCondition", str]
217+
cls, termination_condition: Union["TerminationCondition", str, None]
218218
) -> "Status":
219-
termination_condition = TerminationCondition.process(termination_condition)
219+
termination_condition = TerminationCondition.process(
220+
termination_condition if termination_condition is not None else "unknown"
221+
)
220222
solver_status = SolverStatus.from_termination_condition(termination_condition)
221223
return cls(solver_status, termination_condition)
222224

linopy/constraints.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def rhs(self) -> DataArray:
169169
def dual(self) -> DataArray:
170170
"""Get the dual values DataArray."""
171171

172+
@dual.setter
173+
@abstractmethod
174+
def dual(self, value: DataArray) -> None:
175+
"""Set the dual values DataArray."""
176+
172177
@abstractmethod
173178
def has_variable(self, variable: variables.Variable) -> bool:
174179
"""Check if the constraint references any of the given variable labels."""
@@ -185,6 +190,18 @@ def sanitize_missings(self) -> ConstraintBase:
185190
def sanitize_infinities(self) -> ConstraintBase:
186191
"""Mask out rows with invalid infinite RHS values."""
187192

193+
@abstractmethod
194+
def to_polars(self) -> pl.DataFrame:
195+
"""Convert constraint to a polars DataFrame."""
196+
197+
@abstractmethod
198+
def freeze(self) -> Constraint:
199+
"""Return an immutable Constraint (CSR-backed)."""
200+
201+
@abstractmethod
202+
def mutable(self) -> MutableConstraint:
203+
"""Return a mutable MutableConstraint."""
204+
188205
@abstractmethod
189206
def to_matrix_with_rhs(
190207
self, label_index: VariableLabelIndex
@@ -298,7 +315,8 @@ def mask(self) -> DataArray | None:
298315
(True) and disabled (False).
299316
"""
300317
if self.is_assigned:
301-
return (self.labels != FILL_VALUE["labels"]).astype(bool)
318+
result: DataArray = self.labels != FILL_VALUE["labels"] # type: ignore[assignment]
319+
return result.astype(bool)
302320
return None
303321

304322
@property
@@ -391,7 +409,7 @@ def flat(self) -> pd.DataFrame:
391409
"""
392410
ds = self.data
393411

394-
def mask_func(data: pd.DataFrame) -> pd.Series:
412+
def mask_func(data: dict) -> pd.Series:
395413
mask = (data["vars"] != -1) & (data["coeffs"] != 0)
396414
if "labels" in data:
397415
mask &= data["labels"] != -1
@@ -593,7 +611,7 @@ def nterm(self) -> int:
593611

594612
@property
595613
def coord_names(self) -> list[str]:
596-
return [c.name for c in self._coords]
614+
return [str(c.name) for c in self._coords]
597615

598616
@property
599617
def labels(self) -> DataArray:
@@ -828,8 +846,8 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
828846
)
829847
rhs = ds["rhs"].values
830848
sign = attrs["sign"]
831-
cindex = int(attrs["cindex"])
832-
cindex = cindex if cindex >= 0 else None
849+
_cindex_raw = int(attrs["cindex"])
850+
cindex: int | None = _cindex_raw if _cindex_raw >= 0 else None
833851
coord_dims = attrs["coord_dims"]
834852
if isinstance(coord_dims, str):
835853
coord_dims = [coord_dims]
@@ -892,7 +910,7 @@ def mutable(self) -> MutableConstraint:
892910
"""Convert to a MutableConstraint."""
893911
return MutableConstraint(self.data, self._model, self._name)
894912

895-
def to_polars(self) -> Any:
913+
def to_polars(self) -> pl.DataFrame:
896914
"""Convert to polars DataFrame — delegates to mutable()."""
897915
return self.mutable().to_polars()
898916

@@ -1598,6 +1616,8 @@ def set_blocks(self, block_map: np.ndarray) -> None:
15981616
N = block_map.max()
15991617

16001618
for name, constraint in self.items():
1619+
if not isinstance(constraint, MutableConstraint):
1620+
self.data[name] = constraint = constraint.mutable()
16011621
res = xr.full_like(constraint.labels, N + 1, dtype=block_map.dtype)
16021622
entries = replace_by_map(constraint.vars, block_map)
16031623

@@ -1679,9 +1699,12 @@ def reset_dual(self) -> None:
16791699
cindex=c._cindex,
16801700
dual=None,
16811701
)
1682-
else:
1702+
elif isinstance(c, MutableConstraint):
16831703
if "dual" in c.data:
16841704
c._data = c.data.drop_vars("dual")
1705+
else:
1706+
msg = f"reset_dual encountered an unknown constraint type: {type(c)}"
1707+
raise NotImplementedError(msg)
16851708

16861709

16871710
class AnonymousScalarConstraint:

0 commit comments

Comments
 (0)