|
30 | 30 | from cvxpy.reductions.solvers.nlp_solvers.diff_engine.registry import ATOM_CONVERTERS |
31 | 31 |
|
32 | 32 |
|
33 | | -def _matmul_normalize_1d(A, side): |
34 | | - """Reshape a 1D numpy array to 2D for matmul. |
| 33 | +def _matmul_param_node(arg, child, param_dict): |
| 34 | + """param_node for the constant side of a matmul. |
35 | 35 |
|
36 | | - NumPy matmul treats 1D arrays differently depending on which side: |
37 | | - Left 1D: (k,) → (1, k) — row vector |
38 | | - Right 1D: (k,) → (k, 1) — column vector |
39 | | - 2D input is returned unchanged. |
| 36 | + Returns the parameter capsule for a bare Parameter, the child capsule |
| 37 | + when the constant side contains parameters wrapped in an affine atom |
| 38 | + (so update_params keeps them live in the DAG), or None otherwise. |
40 | 39 | """ |
41 | | - if A.ndim == 1: |
42 | | - return A.reshape(1, -1) if side == 'left' else A.reshape(-1, 1) |
43 | | - return A |
| 40 | + if isinstance(arg, cp.Parameter): |
| 41 | + return param_dict[arg.id] |
| 42 | + if arg.parameters(): |
| 43 | + return child |
| 44 | + return None |
44 | 45 |
|
45 | 46 |
|
46 | 47 | def convert_matmul(expr, children, var_dict, n_vars, param_dict): |
47 | 48 | """Convert matrix multiplication A @ f(x), f(x) @ A, or X @ Y. |
48 | 49 |
|
49 | | - NumPy matmul semantics for 1D arrays: |
50 | | - (n,) @ (m,k) → treat left as (1,n) — normalize_shape already does this |
51 | | - (m,k) @ (n,) → treat right as (n,1) — must reshape from (1,n) storage |
52 | | - (n,) @ (n,) → dot product: (1,n) @ (n,1) → scalar |
53 | | -
|
54 | | - The C engine only has 2D nodes. 1D expressions are stored as (1,n) by |
55 | | - normalize_shape. All 1D→2D matmul normalization is handled here so that |
56 | | - helper functions always receive properly shaped 2D data. |
| 50 | + 1D operands are stored as (1, n) in the C engine. Left 1D stays (1, n); |
| 51 | + right 1D must be reshaped to (n, 1) for matmul. |
57 | 52 | """ |
58 | 53 | left_arg, right_arg = expr.args |
59 | 54 | left_child, right_child = children |
60 | 55 |
|
61 | | - # Right 1D child: C stores as (1, n) but matmul needs (n, 1). |
62 | | - # Do this once, before branching — used by all three branches. |
63 | 56 | if len(right_arg.shape) <= 1 and right_arg.size > 1: |
64 | 57 | right_child = _diffengine.make_reshape(right_child, right_arg.size, 1) |
65 | 58 |
|
66 | 59 | if left_arg.is_constant(): |
67 | | - A = _matmul_normalize_1d(left_arg.value, 'left') |
68 | | - if isinstance(left_arg, cp.Parameter): |
69 | | - param_node = param_dict[left_arg.id] |
70 | | - elif left_arg.parameters(): |
71 | | - param_node = left_child |
72 | | - else: |
73 | | - param_node = None |
| 60 | + A = left_arg.value |
| 61 | + if A.ndim == 1: |
| 62 | + A = A.reshape(1, -1) |
| 63 | + param_node = _matmul_param_node(left_arg, left_child, param_dict) |
74 | 64 | if sparse.issparse(A): |
75 | 65 | return make_sparse_left_matmul(param_node, right_child, A) |
76 | 66 | return make_dense_left_matmul(param_node, right_child, A) |
77 | 67 |
|
78 | | - elif right_arg.is_constant(): |
79 | | - A = _matmul_normalize_1d(right_arg.value, 'right') |
80 | | - if isinstance(right_arg, cp.Parameter): |
81 | | - param_node = param_dict[right_arg.id] |
82 | | - elif right_arg.parameters(): |
83 | | - param_node = right_child |
84 | | - else: |
85 | | - param_node = None |
| 68 | + if right_arg.is_constant(): |
| 69 | + A = right_arg.value |
| 70 | + if A.ndim == 1: |
| 71 | + A = A.reshape(-1, 1) |
| 72 | + param_node = _matmul_param_node(right_arg, right_child, param_dict) |
86 | 73 | if sparse.issparse(A): |
87 | 74 | return make_sparse_right_matmul(param_node, left_child, A) |
88 | 75 | return make_dense_right_matmul(param_node, left_child, A) |
89 | 76 |
|
90 | | - else: |
91 | | - return _diffengine.make_matmul(left_child, right_child) |
| 77 | + return _diffengine.make_matmul(left_child, right_child) |
92 | 78 |
|
93 | 79 | # TODO we should support sparse elementwise multiply at some point. |
94 | 80 | def convert_multiply(expr, children, var_dict, n_vars, param_dict): |
@@ -156,9 +142,7 @@ def convert_expr(expr, var_dict, n_vars, param_dict=None): |
156 | 142 | d1_Python, d2_Python = normalize_shape(expr.shape) |
157 | 143 |
|
158 | 144 | if d1_C != d1_Python or d2_C != d2_Python: |
159 | | - # 1D Python shapes (n,) normalize to (1, n), but the C engine may |
160 | | - # produce (n, 1) — e.g. matrix @ scalar or transpose of a vector. |
161 | | - # Both represent the same 1D data; reshape to match Python convention. |
| 145 | + # 1D shape (n,) normalizes to (1, n) but C may produce (n, 1); reshape. |
162 | 146 | if len(expr.shape) <= 1 and d1_C * d2_C == d1_Python * d2_Python: |
163 | 147 | C_expr = _diffengine.make_reshape(C_expr, d1_Python, d2_Python) |
164 | 148 | else: |
|
0 commit comments