|
31 | 31 | PWL_LAMBDA_SUFFIX, |
32 | 32 | PWL_SELECT_SUFFIX, |
33 | 33 | PWL_X_LINK_SUFFIX, |
34 | | - PWL_Y_LINK_SUFFIX, |
35 | 34 | SEGMENT_DIM, |
36 | 35 | ) |
37 | 36 |
|
@@ -682,21 +681,17 @@ def add_piecewise_constraints( |
682 | 681 | active_expr = _to_linexpr(active) if active is not None else None |
683 | 682 |
|
684 | 683 | if disjunctive: |
685 | | - # Disjunctive only supports 2-variable for now |
686 | | - if len(coerced) != 2: |
| 684 | + if method == "incremental": |
687 | 685 | raise ValueError( |
688 | | - "Disjunctive piecewise constraints currently support " |
689 | | - "exactly 2 (expression, breakpoints) pairs." |
| 686 | + "Incremental method is not supported for disjunctive constraints" |
690 | 687 | ) |
691 | 688 | return _add_disjunctive( |
692 | 689 | model, |
693 | 690 | name, |
694 | | - lin_exprs[0], |
695 | | - lin_exprs[1], |
696 | | - bp_list[0], |
697 | | - bp_list[1], |
| 691 | + lin_exprs, |
| 692 | + bp_list, |
| 693 | + link_coords, |
698 | 694 | bp_mask, |
699 | | - method, |
700 | 695 | active_expr, |
701 | 696 | ) |
702 | 697 |
|
@@ -901,68 +896,81 @@ def _add_incremental( |
901 | 896 | def _add_disjunctive( |
902 | 897 | model: Model, |
903 | 898 | name: str, |
904 | | - x_expr: LinearExpression, |
905 | | - y_expr: LinearExpression, |
906 | | - x_points: DataArray, |
907 | | - y_points: DataArray, |
908 | | - mask: DataArray | None, |
909 | | - method: str, |
| 899 | + lin_exprs: list[LinearExpression], |
| 900 | + bp_list: list[DataArray], |
| 901 | + link_coords: list[str], |
| 902 | + bp_mask: DataArray | None, |
910 | 903 | active: LinearExpression | None = None, |
911 | 904 | ) -> Constraint: |
912 | | - """Handle disjunctive piecewise equality constraints (2-variable only).""" |
913 | | - if method == "incremental": |
914 | | - raise ValueError( |
915 | | - "Incremental method is not supported for disjunctive constraints" |
916 | | - ) |
| 905 | + """Disjunctive SOS2 formulation for N-variable piecewise equality.""" |
| 906 | + from linopy.expressions import LinearExpression |
917 | 907 |
|
918 | | - _validate_numeric_breakpoint_coords(x_points) |
919 | | - if not _has_trailing_nan_only(x_points): |
| 908 | + link_dim = "_pwl_var" |
| 909 | + stacked_bp = _stack_along_link(bp_list, link_coords, link_dim) |
| 910 | + |
| 911 | + _validate_numeric_breakpoint_coords(stacked_bp) |
| 912 | + if not _has_trailing_nan_only(stacked_bp): |
920 | 913 | raise ValueError( |
921 | 914 | "Disjunctive SOS2 does not support non-trailing NaN breakpoints. " |
922 | 915 | "NaN values must only appear at the end of the breakpoint sequence." |
923 | 916 | ) |
924 | 917 |
|
925 | | - binary_name = f"{name}{PWL_BINARY_SUFFIX}" |
926 | | - select_name = f"{name}{PWL_SELECT_SUFFIX}" |
927 | | - lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}" |
928 | | - convex_name = f"{name}{PWL_CONVEX_SUFFIX}" |
929 | | - x_link_name = f"{name}{PWL_X_LINK_SUFFIX}" |
930 | | - y_link_name = f"{name}{PWL_Y_LINK_SUFFIX}" |
| 918 | + # Stack expressions along link dimension |
| 919 | + stacked_data = _stack_along_link( |
| 920 | + [e.data for e in lin_exprs], link_coords, link_dim |
| 921 | + ) |
| 922 | + target_expr = LinearExpression(stacked_data, model) |
931 | 923 |
|
932 | | - extra = _var_coords_from(x_points, exclude={BREAKPOINT_DIM, SEGMENT_DIM}) |
| 924 | + # Compute stacked mask |
| 925 | + stacked_mask = None |
| 926 | + if bp_mask is not None: |
| 927 | + stacked_mask = _stack_along_link( |
| 928 | + [bp_mask] * len(link_coords), link_coords, link_dim |
| 929 | + ) |
| 930 | + |
| 931 | + dim = BREAKPOINT_DIM |
| 932 | + extra = _var_coords_from(stacked_bp, exclude={dim, SEGMENT_DIM, link_dim}) |
933 | 933 | lambda_coords = extra + [ |
934 | | - pd.Index(x_points.coords[SEGMENT_DIM].values, name=SEGMENT_DIM), |
935 | | - pd.Index(x_points.coords[BREAKPOINT_DIM].values, name=BREAKPOINT_DIM), |
| 934 | + pd.Index(stacked_bp.coords[SEGMENT_DIM].values, name=SEGMENT_DIM), |
| 935 | + pd.Index(stacked_bp.coords[dim].values, name=dim), |
936 | 936 | ] |
937 | 937 | binary_coords = extra + [ |
938 | | - pd.Index(x_points.coords[SEGMENT_DIM].values, name=SEGMENT_DIM), |
| 938 | + pd.Index(stacked_bp.coords[SEGMENT_DIM].values, name=SEGMENT_DIM), |
939 | 939 | ] |
940 | 940 |
|
941 | | - binary_mask = mask.any(dim=BREAKPOINT_DIM) if mask is not None else None |
| 941 | + # Masks |
| 942 | + lambda_mask = None |
| 943 | + binary_mask = None |
| 944 | + if stacked_mask is not None: |
| 945 | + # Aggregate across link_dim — all variables must be valid |
| 946 | + agg_mask = stacked_mask.all(dim=link_dim) |
| 947 | + lambda_mask = agg_mask |
| 948 | + binary_mask = agg_mask.any(dim=dim) |
| 949 | + |
| 950 | + binary_name = f"{name}{PWL_BINARY_SUFFIX}" |
| 951 | + select_name = f"{name}{PWL_SELECT_SUFFIX}" |
| 952 | + lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}" |
| 953 | + convex_name = f"{name}{PWL_CONVEX_SUFFIX}" |
| 954 | + link_name = f"{name}{PWL_X_LINK_SUFFIX}" |
942 | 955 |
|
943 | 956 | binary_var = model.add_variables( |
944 | 957 | binary=True, coords=binary_coords, name=binary_name, mask=binary_mask |
945 | 958 | ) |
946 | 959 |
|
947 | 960 | rhs = active if active is not None else 1 |
948 | | - select_con = model.add_constraints( |
| 961 | + model.add_constraints( |
949 | 962 | binary_var.sum(dim=SEGMENT_DIM) == rhs, name=select_name |
950 | 963 | ) |
951 | 964 |
|
952 | 965 | lambda_var = model.add_variables( |
953 | | - lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=mask |
| 966 | + lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask |
954 | 967 | ) |
955 | 968 |
|
956 | | - model.add_sos_constraints(lambda_var, sos_type=2, sos_dim=BREAKPOINT_DIM) |
| 969 | + model.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) |
957 | 970 |
|
958 | 971 | model.add_constraints( |
959 | | - lambda_var.sum(dim=BREAKPOINT_DIM) == binary_var, name=convex_name |
| 972 | + lambda_var.sum(dim=dim) == binary_var, name=convex_name |
960 | 973 | ) |
961 | 974 |
|
962 | | - x_weighted = (lambda_var * x_points).sum(dim=[SEGMENT_DIM, BREAKPOINT_DIM]) |
963 | | - model.add_constraints(x_expr == x_weighted, name=x_link_name) |
964 | | - |
965 | | - y_weighted = (lambda_var * y_points).sum(dim=[SEGMENT_DIM, BREAKPOINT_DIM]) |
966 | | - model.add_constraints(y_expr == y_weighted, name=y_link_name) |
967 | | - |
968 | | - return select_con |
| 975 | + weighted = (lambda_var * stacked_bp).sum(dim=[SEGMENT_DIM, dim]) |
| 976 | + return model.add_constraints(target_expr == weighted, name=link_name) |
0 commit comments