Skip to content

Commit 0af591f

Browse files
committed
1. Moved cumsum, groupby, and rolling methods from BaseExpression to LinearExpression
2. QuadraticExpression now defines its own versions without overriding anything from BaseExpression 3. No more Liskov substitution principle violations 4. No # type: ignore hacks needed
1 parent 3c7a48a commit 0af591f

1 file changed

Lines changed: 113 additions & 117 deletions

File tree

linopy/expressions.py

Lines changed: 113 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -796,55 +796,6 @@ def sum(
796796

797797
return res
798798

799-
def cumsum(
800-
self,
801-
dim: DimsLike | None = None,
802-
*,
803-
skipna: bool | None = None,
804-
keep_attrs: bool | None = None,
805-
**kwargs: Any,
806-
) -> LinearExpression:
807-
"""
808-
Cumulated sum along a given axis.
809-
810-
Docstring and arguments are borrowed from `xarray.Dataset.cumsum`
811-
812-
Parameters
813-
----------
814-
dim : str, Iterable of Hashable, "..." or None, default: None
815-
Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"``
816-
or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions.
817-
skipna : bool or None, optional
818-
If True, skip missing values (as marked by NaN). By default, only
819-
skips missing values for float dtypes; other dtypes either do not
820-
have a sentinel missing value (int) or ``skipna=True`` has not been
821-
implemented (object, datetime64 or timedelta64).
822-
keep_attrs : bool or None, optional
823-
If True, ``attrs`` will be copied from the original
824-
object to the new one. If False, the new object will be
825-
returned without attributes.
826-
**kwargs : Any
827-
Additional keyword arguments passed on to the appropriate array
828-
function for calculating ``cumsum`` on this object's data.
829-
These could include dask-specific kwargs like ``split_every``.
830-
831-
Returns
832-
-------
833-
linopy.expression.LinearExpression
834-
"""
835-
# Along every dimensions, we want to perform cumsum along, get the size of the
836-
# dimension to pass that to self.rolling.
837-
if not dim:
838-
# If user did not specify a dimension to sum over, use all relevant
839-
# dimensions
840-
dim = self.coord_dims
841-
if isinstance(dim, str):
842-
dim = [dim]
843-
elif isinstance(dim, EllipsisType) or dim is None:
844-
dim = self.coord_dims
845-
dim_dict = {dim_name: self.data.sizes[dim_name] for dim_name in dim}
846-
return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna)
847-
848799
def to_constraint(
849800
self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike
850801
) -> Constraint:
@@ -991,74 +942,6 @@ def diff(self: GenericExpression, dim: str, n: int = 1) -> GenericExpression:
991942
"""
992943
return self - self.shift({dim: n})
993944

994-
def groupby(
995-
self,
996-
group: DataFrame | Series | DataArray,
997-
restore_coord_dims: bool | None = None,
998-
**kwargs: Any,
999-
) -> LinearExpressionGroupby:
1000-
"""
1001-
Returns a LinearExpressionGroupBy object for performing grouped
1002-
operations.
1003-
1004-
Docstring and arguments are borrowed from `xarray.Dataset.groupby`
1005-
1006-
Parameters
1007-
----------
1008-
group : str, DataArray or IndexVariable
1009-
Array whose unique values should be used to group this array. If a
1010-
string, must be the name of a variable contained in this dataset.
1011-
restore_coord_dims : bool, optional
1012-
If True, also restore the dimension order of multi-dimensional
1013-
coordinates.
1014-
1015-
Returns
1016-
-------
1017-
grouped
1018-
A `LinearExpressionGroupBy` containing the xarray groups and ensuring
1019-
the correct return type.
1020-
"""
1021-
ds = self.data
1022-
kwargs = dict(restore_coord_dims=restore_coord_dims, **kwargs)
1023-
return LinearExpressionGroupby(ds, group, model=self.model, kwargs=kwargs)
1024-
1025-
def rolling(
1026-
self,
1027-
dim: Mapping[Any, int] | None = None,
1028-
min_periods: int | None = None,
1029-
center: bool | Mapping[Any, bool] = False,
1030-
**window_kwargs: int,
1031-
) -> LinearExpressionRolling:
1032-
"""
1033-
Rolling window object.
1034-
1035-
Docstring and arguments are borrowed from `xarray.Dataset.rolling`
1036-
1037-
Parameters
1038-
----------
1039-
dim : dict, optional
1040-
Mapping from the dimension name to create the rolling iterator
1041-
along (e.g. `time`) to its moving window size.
1042-
min_periods : int, default: None
1043-
Minimum number of observations in window required to have a value
1044-
(otherwise result is NA). The default, None, is equivalent to
1045-
setting min_periods equal to the size of the window.
1046-
center : bool or mapping, default: False
1047-
Set the labels at the center of the window.
1048-
**window_kwargs : optional
1049-
The keyword arguments form of ``dim``.
1050-
One of dim or window_kwargs must be provided.
1051-
1052-
Returns
1053-
-------
1054-
linopy.expression.LinearExpressionRolling
1055-
"""
1056-
ds = self.data
1057-
rolling = ds.rolling(
1058-
dim=dim, min_periods=min_periods, center=center, **window_kwargs
1059-
)
1060-
return LinearExpressionRolling(rolling, model=self.model)
1061-
1062945
@property
1063946
def nterm(self) -> int:
1064947
"""
@@ -1651,6 +1534,119 @@ def process_one(
16511534

16521535
return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0]
16531536

1537+
def cumsum(
1538+
self,
1539+
dim: DimsLike | None = None,
1540+
*,
1541+
skipna: bool | None = None,
1542+
keep_attrs: bool | None = None,
1543+
**kwargs: Any,
1544+
) -> LinearExpression:
1545+
"""
1546+
Cumulated sum along a given axis.
1547+
1548+
Docstring and arguments are borrowed from `xarray.Dataset.cumsum`
1549+
1550+
Parameters
1551+
----------
1552+
dim : str, Iterable of Hashable, "..." or None, default: None
1553+
Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"``
1554+
or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions.
1555+
skipna : bool or None, optional
1556+
If True, skip missing values (as marked by NaN). By default, only
1557+
skips missing values for float dtypes; other dtypes either do not
1558+
have a sentinel missing value (int) or ``skipna=True`` has not been
1559+
implemented (object, datetime64 or timedelta64).
1560+
keep_attrs : bool or None, optional
1561+
If True, ``attrs`` will be copied from the original
1562+
object to the new one. If False, the new object will be
1563+
returned without attributes.
1564+
**kwargs : Any
1565+
Additional keyword arguments passed on to the appropriate array
1566+
function for calculating ``cumsum`` on this object's data.
1567+
These could include dask-specific kwargs like ``split_every``.
1568+
1569+
Returns
1570+
-------
1571+
linopy.expression.LinearExpression
1572+
"""
1573+
if not dim:
1574+
dim = self.coord_dims
1575+
if isinstance(dim, str):
1576+
dim = [dim]
1577+
elif isinstance(dim, EllipsisType) or dim is None:
1578+
dim = self.coord_dims
1579+
dim_dict = {dim_name: self.data.sizes[dim_name] for dim_name in dim}
1580+
return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna)
1581+
1582+
def groupby(
1583+
self,
1584+
group: DataFrame | Series | DataArray,
1585+
restore_coord_dims: bool | None = None,
1586+
**kwargs: Any,
1587+
) -> LinearExpressionGroupby:
1588+
"""
1589+
Returns a LinearExpressionGroupBy object for performing grouped
1590+
operations.
1591+
1592+
Docstring and arguments are borrowed from `xarray.Dataset.groupby`
1593+
1594+
Parameters
1595+
----------
1596+
group : str, DataArray or IndexVariable
1597+
Array whose unique values should be used to group this array. If a
1598+
string, must be the name of a variable contained in this dataset.
1599+
restore_coord_dims : bool, optional
1600+
If True, also restore the dimension order of multi-dimensional
1601+
coordinates.
1602+
1603+
Returns
1604+
-------
1605+
grouped
1606+
A `LinearExpressionGroupBy` containing the xarray groups and ensuring
1607+
the correct return type.
1608+
"""
1609+
ds = self.data
1610+
kwargs = dict(restore_coord_dims=restore_coord_dims, **kwargs)
1611+
return LinearExpressionGroupby(ds, group, model=self.model, kwargs=kwargs)
1612+
1613+
def rolling(
1614+
self,
1615+
dim: Mapping[Any, int] | None = None,
1616+
min_periods: int | None = None,
1617+
center: bool | Mapping[Any, bool] = False,
1618+
**window_kwargs: int,
1619+
) -> LinearExpressionRolling:
1620+
"""
1621+
Rolling window object.
1622+
1623+
Docstring and arguments are borrowed from `xarray.Dataset.rolling`
1624+
1625+
Parameters
1626+
----------
1627+
dim : dict, optional
1628+
Mapping from the dimension name to create the rolling iterator
1629+
along (e.g. `time`) to its moving window size.
1630+
min_periods : int, default: None
1631+
Minimum number of observations in window required to have a value
1632+
(otherwise result is NA). The default, None, is equivalent to
1633+
setting min_periods equal to the size of the window.
1634+
center : bool or mapping, default: False
1635+
Set the labels at the center of the window.
1636+
**window_kwargs : optional
1637+
The keyword arguments form of ``dim``.
1638+
One of dim or window_kwargs must be provided.
1639+
1640+
Returns
1641+
-------
1642+
linopy.expression.LinearExpressionRolling
1643+
"""
1644+
ds = self.data
1645+
rolling = ds.rolling(
1646+
dim=dim, min_periods=min_periods, center=center, **window_kwargs
1647+
)
1648+
return LinearExpressionRolling(rolling, model=self.model)
1649+
16541650

16551651
class QuadraticExpression(BaseExpression):
16561652
"""

0 commit comments

Comments
 (0)