Skip to content

Commit 315d5f9

Browse files
committed
Update to the new Function interface and update some byte->bool indexing
1 parent bb156fe commit 315d5f9

4 files changed

Lines changed: 169 additions & 173 deletions

File tree

qpth/qp.py

Lines changed: 164 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -15,176 +15,172 @@ class QPSolvers(Enum):
1515
CVXPY = 2
1616

1717

18-
class QPFunction(Function):
19-
def __init__(self, eps=1e-12, verbose=0, notImprovedLim=3,
18+
def QPFunction(eps=1e-12, verbose=0, notImprovedLim=3,
2019
maxIter=20, solver=QPSolvers.PDIPM_BATCHED,
2120
check_Q_spd=True):
22-
self.eps = eps
23-
self.verbose = verbose
24-
self.notImprovedLim = notImprovedLim
25-
self.maxIter = maxIter
26-
self.solver = solver
27-
self.check_Q_spd = check_Q_spd
28-
29-
def forward(self, Q_, p_, G_, h_, A_, b_):
30-
"""Solve a batch of QPs.
31-
32-
This function solves a batch of QPs, each optimizing over
33-
`nz` variables and having `nineq` inequality constraints
34-
and `neq` equality constraints.
35-
The optimization problem for each instance in the batch
36-
(dropping indexing from the notation) is of the form
37-
38-
\hat z = argmin_z 1/2 z^T Q z + p^T z
39-
subject to Gz <= h
40-
Az = b
41-
42-
where Q \in S^{nz,nz},
43-
S^{nz,nz} is the set of all positive semi-definite matrices,
44-
p \in R^{nz}
45-
G \in R^{nineq,nz}
46-
h \in R^{nineq}
47-
A \in R^{neq,nz}
48-
b \in R^{neq}
49-
50-
These parameters should all be passed to this function as
51-
Variable- or Parameter-wrapped Tensors.
52-
(See torch.autograd.Variable and torch.nn.parameter.Parameter)
53-
54-
If you want to solve a batch of QPs where `nz`, `nineq` and `neq`
55-
are the same, but some of the contents differ across the
56-
minibatch, you can pass in tensors in the standard way
57-
where the first dimension indicates the batch example.
58-
This can be done with some or all of the coefficients.
59-
60-
You do not need to add an extra dimension to coefficients
61-
that will not change across all of the minibatch examples.
62-
This function is able to infer such cases.
63-
64-
If you don't want to use any equality or inequality constraints,
65-
you can set the appropriate values to:
66-
67-
e = Variable(torch.Tensor())
68-
69-
Parameters:
70-
Q: A (nBatch, nz, nz) or (nz, nz) Tensor.
71-
p: A (nBatch, nz) or (nz) Tensor.
72-
G: A (nBatch, nineq, nz) or (nineq, nz) Tensor.
73-
h: A (nBatch, nineq) or (nineq) Tensor.
74-
A: A (nBatch, neq, nz) or (neq, nz) Tensor.
75-
b: A (nBatch, neq) or (neq) Tensor.
76-
77-
Returns: \hat z: a (nBatch, nz) Tensor.
78-
"""
79-
nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
80-
Q, _ = expandParam(Q_, nBatch, 3)
81-
p, _ = expandParam(p_, nBatch, 2)
82-
G, _ = expandParam(G_, nBatch, 3)
83-
h, _ = expandParam(h_, nBatch, 2)
84-
A, _ = expandParam(A_, nBatch, 3)
85-
b, _ = expandParam(b_, nBatch, 2)
86-
87-
if self.check_Q_spd:
88-
for i in range(nBatch):
89-
e, _ = torch.eig(Q[i])
90-
if not torch.all(e[:,0] > 0):
91-
raise RuntimeError('Q is not SPD.')
92-
93-
_, nineq, nz = G.size()
94-
neq = A.size(1) if A.nelement() > 0 else 0
95-
assert(neq > 0 or nineq > 0)
96-
self.neq, self.nineq, self.nz = neq, nineq, nz
97-
98-
if self.solver == QPSolvers.PDIPM_BATCHED:
99-
self.Q_LU, self.S_LU, self.R = pdipm_b.pre_factor_kkt(Q, G, A)
100-
zhats, self.nus, self.lams, self.slacks = pdipm_b.forward(
101-
Q, p, G, h, A, b, self.Q_LU, self.S_LU, self.R,
102-
self.eps, self.verbose, self.notImprovedLim, self.maxIter)
103-
elif self.solver == QPSolvers.CVXPY:
104-
vals = torch.Tensor(nBatch).type_as(Q)
105-
zhats = torch.Tensor(nBatch, self.nz).type_as(Q)
106-
lams = torch.Tensor(nBatch, self.nineq).type_as(Q)
107-
nus = torch.Tensor(nBatch, self.neq).type_as(Q) \
108-
if self.neq > 0 else torch.Tensor()
109-
slacks = torch.Tensor(nBatch, self.nineq).type_as(Q)
110-
for i in range(nBatch):
111-
Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
112-
vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
113-
*[x.cpu().numpy() if x is not None else None
114-
for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
115-
# if zhati[0] is None:
116-
# import IPython, sys; IPython.embed(); sys.exit(-1)
117-
zhats[i] = torch.Tensor(zhati)
118-
lams[i] = torch.Tensor(lami)
119-
slacks[i] = torch.Tensor(si)
120-
if neq > 0:
121-
nus[i] = torch.Tensor(nui)
122-
123-
self.vals = vals
124-
self.lams = lams
125-
self.nus = nus
126-
self.slacks = slacks
127-
else:
128-
assert False
129-
130-
self.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_)
131-
return zhats
132-
133-
def backward(self, dl_dzhat):
134-
zhats, Q, p, G, h, A, b = self.saved_tensors
135-
nBatch = extract_nBatch(Q, p, G, h, A, b)
136-
Q, Q_e = expandParam(Q, nBatch, 3)
137-
p, p_e = expandParam(p, nBatch, 2)
138-
G, G_e = expandParam(G, nBatch, 3)
139-
h, h_e = expandParam(h, nBatch, 2)
140-
A, A_e = expandParam(A, nBatch, 3)
141-
b, b_e = expandParam(b, nBatch, 2)
142-
143-
# neq, nineq, nz = self.neq, self.nineq, self.nz
144-
neq, nineq = self.neq, self.nineq
145-
146-
147-
if self.solver == QPSolvers.CVXPY:
148-
self.Q_LU, self.S_LU, self.R = pdipm_b.pre_factor_kkt(Q, G, A)
149-
150-
# Clamp here to avoid issues coming up when the slacks are too small.
151-
# TODO: A better fix would be to get lams and slacks from the
152-
# solver that don't have this issue.
153-
d = torch.clamp(self.lams, min=1e-8) / torch.clamp(self.slacks, min=1e-8)
154-
155-
pdipm_b.factor_kkt(self.S_LU, self.R, d)
156-
dx, _, dlam, dnu = pdipm_b.solve_kkt(
157-
self.Q_LU, d, G, A, self.S_LU,
158-
dl_dzhat, torch.zeros(nBatch, nineq).type_as(G),
159-
torch.zeros(nBatch, nineq).type_as(G),
160-
torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor())
161-
162-
dps = dx
163-
dGs = bger(dlam, zhats) + bger(self.lams, dx)
164-
if G_e:
165-
dGs = dGs.mean(0)
166-
dhs = -dlam
167-
if h_e:
168-
dhs = dhs.mean(0)
169-
if neq > 0:
170-
dAs = bger(dnu, zhats) + bger(self.nus, dx)
171-
dbs = -dnu
172-
if A_e:
173-
dAs = dAs.mean(0)
174-
if b_e:
175-
dbs = dbs.mean(0)
176-
else:
177-
dAs, dbs = None, None
178-
dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx))
179-
if Q_e:
180-
dQs = dQs.mean(0)
181-
if p_e:
182-
dps = dps.mean(0)
183-
184-
185-
grads = (dQs, dps, dGs, dhs, dAs, dbs)
186-
187-
return grads
21+
class QPFunctionFn(Function):
22+
@staticmethod
23+
def forward(ctx, Q_, p_, G_, h_, A_, b_):
24+
"""Solve a batch of QPs.
25+
26+
This function solves a batch of QPs, each optimizing over
27+
`nz` variables and having `nineq` inequality constraints
28+
and `neq` equality constraints.
29+
The optimization problem for each instance in the batch
30+
(dropping indexing from the notation) is of the form
31+
32+
\hat z = argmin_z 1/2 z^T Q z + p^T z
33+
subject to Gz <= h
34+
Az = b
35+
36+
where Q \in S^{nz,nz},
37+
S^{nz,nz} is the set of all positive semi-definite matrices,
38+
p \in R^{nz}
39+
G \in R^{nineq,nz}
40+
h \in R^{nineq}
41+
A \in R^{neq,nz}
42+
b \in R^{neq}
43+
44+
These parameters should all be passed to this function as
45+
Variable- or Parameter-wrapped Tensors.
46+
(See torch.autograd.Variable and torch.nn.parameter.Parameter)
47+
48+
If you want to solve a batch of QPs where `nz`, `nineq` and `neq`
49+
are the same, but some of the contents differ across the
50+
minibatch, you can pass in tensors in the standard way
51+
where the first dimension indicates the batch example.
52+
This can be done with some or all of the coefficients.
53+
54+
You do not need to add an extra dimension to coefficients
55+
that will not change across all of the minibatch examples.
56+
This function is able to infer such cases.
57+
58+
If you don't want to use any equality or inequality constraints,
59+
you can set the appropriate values to:
60+
61+
e = Variable(torch.Tensor())
62+
63+
Parameters:
64+
Q: A (nBatch, nz, nz) or (nz, nz) Tensor.
65+
p: A (nBatch, nz) or (nz) Tensor.
66+
G: A (nBatch, nineq, nz) or (nineq, nz) Tensor.
67+
h: A (nBatch, nineq) or (nineq) Tensor.
68+
A: A (nBatch, neq, nz) or (neq, nz) Tensor.
69+
b: A (nBatch, neq) or (neq) Tensor.
70+
71+
Returns: \hat z: a (nBatch, nz) Tensor.
72+
"""
73+
nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
74+
Q, _ = expandParam(Q_, nBatch, 3)
75+
p, _ = expandParam(p_, nBatch, 2)
76+
G, _ = expandParam(G_, nBatch, 3)
77+
h, _ = expandParam(h_, nBatch, 2)
78+
A, _ = expandParam(A_, nBatch, 3)
79+
b, _ = expandParam(b_, nBatch, 2)
80+
81+
if check_Q_spd:
82+
for i in range(nBatch):
83+
e, _ = torch.eig(Q[i])
84+
if not torch.all(e[:,0] > 0):
85+
raise RuntimeError('Q is not SPD.')
86+
87+
_, nineq, nz = G.size()
88+
neq = A.size(1) if A.nelement() > 0 else 0
89+
assert(neq > 0 or nineq > 0)
90+
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz
91+
92+
if solver == QPSolvers.PDIPM_BATCHED:
93+
ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
94+
zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward(
95+
Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R,
96+
eps, verbose, notImprovedLim, maxIter)
97+
elif solver == QPSolvers.CVXPY:
98+
vals = torch.Tensor(nBatch).type_as(Q)
99+
zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
100+
lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
101+
nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \
102+
if ctx.neq > 0 else torch.Tensor()
103+
slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
104+
for i in range(nBatch):
105+
Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
106+
vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
107+
*[x.cpu().numpy() if x is not None else None
108+
for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
109+
# if zhati[0] is None:
110+
# import IPython, sys; IPython.embed(); sys.exit(-1)
111+
zhats[i] = torch.Tensor(zhati)
112+
lams[i] = torch.Tensor(lami)
113+
slacks[i] = torch.Tensor(si)
114+
if neq > 0:
115+
nus[i] = torch.Tensor(nui)
116+
117+
ctx.vals = vals
118+
ctx.lams = lams
119+
ctx.nus = nus
120+
ctx.slacks = slacks
121+
else:
122+
assert False
123+
124+
ctx.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_)
125+
return zhats
126+
127+
@staticmethod
128+
def backward(ctx, dl_dzhat):
129+
zhats, Q, p, G, h, A, b = ctx.saved_tensors
130+
nBatch = extract_nBatch(Q, p, G, h, A, b)
131+
Q, Q_e = expandParam(Q, nBatch, 3)
132+
p, p_e = expandParam(p, nBatch, 2)
133+
G, G_e = expandParam(G, nBatch, 3)
134+
h, h_e = expandParam(h, nBatch, 2)
135+
A, A_e = expandParam(A, nBatch, 3)
136+
b, b_e = expandParam(b, nBatch, 2)
137+
138+
# neq, nineq, nz = ctx.neq, ctx.nineq, ctx.nz
139+
neq, nineq = ctx.neq, ctx.nineq
140+
141+
142+
if solver == QPSolvers.CVXPY:
143+
ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
144+
145+
# Clamp here to avoid issues coming up when the slacks are too small.
146+
# TODO: A better fix would be to get lams and slacks from the
147+
# solver that don't have this issue.
148+
d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8)
149+
150+
pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d)
151+
dx, _, dlam, dnu = pdipm_b.solve_kkt(
152+
ctx.Q_LU, d, G, A, ctx.S_LU,
153+
dl_dzhat, torch.zeros(nBatch, nineq).type_as(G),
154+
torch.zeros(nBatch, nineq).type_as(G),
155+
torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor())
156+
157+
dps = dx
158+
dGs = bger(dlam, zhats) + bger(ctx.lams, dx)
159+
if G_e:
160+
dGs = dGs.mean(0)
161+
dhs = -dlam
162+
if h_e:
163+
dhs = dhs.mean(0)
164+
if neq > 0:
165+
dAs = bger(dnu, zhats) + bger(ctx.nus, dx)
166+
dbs = -dnu
167+
if A_e:
168+
dAs = dAs.mean(0)
169+
if b_e:
170+
dbs = dbs.mean(0)
171+
else:
172+
dAs, dbs = None, None
173+
dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx))
174+
if Q_e:
175+
dQs = dQs.mean(0)
176+
if p_e:
177+
dps = dps.mean(0)
178+
179+
180+
grads = (dQs, dps, dGs, dhs, dAs, dbs)
181+
182+
return grads
183+
return QPFunctionFn.apply
188184

189185

190186
class SpQPFunction(Function):

qpth/solvers/pdipm/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def factor_kkt(S_LU, R, d):
437437
if factor_kkt_eye is None or factor_kkt_eye.size() != d.size():
438438
# print('Updating batchedEye size.')
439439
factor_kkt_eye = torch.eye(nineq).repeat(
440-
nBatch, 1, 1).type_as(R).byte()
440+
nBatch, 1, 1).type_as(R).bool()
441441
T = R.clone()
442442
T[factor_kkt_eye] += (1. / d).squeeze().view(-1)
443443

qpth/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_sizes(G, A=None):
3636
def bdiag(d):
3737
nBatch, sz = d.size()
3838
D = torch.zeros(nBatch, sz, sz).type_as(d)
39-
I = torch.eye(sz).repeat(nBatch, 1, 1).type_as(d).byte()
39+
I = torch.eye(sz).repeat(nBatch, 1, 1).type_as(d).bool()
4040
D[I] = d.squeeze().view(-1)
4141
return D
4242

test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
import numpy.random as npr
1515
import numpy.testing as npt
16-
from numpy.testing import decorators
16+
from numpy.testing import dec
1717
np.set_printoptions(precision=6)
1818

1919
import numdifftools as nd
@@ -247,7 +247,7 @@ def test_ir_kkt_solver():
247247
npt.assert_allclose(dy.numpy(), dy_.numpy(), rtol=RTOL, atol=ATOL)
248248

249249

250-
@npt.decorators.skipif(
250+
@npt.dec.skipif(
251251
not torch.cuda.is_available() or not hasattr(torch, 'spbqrfactsolve'))
252252
def test_sparse_forward():
253253
torch.manual_seed(0)
@@ -300,7 +300,7 @@ def cast(m):
300300
xhats_qpf.cpu().numpy(), rtol=RTOL, atol=ATOL)
301301

302302

303-
@npt.decorators.skipif(
303+
@npt.dec.skipif(
304304
not torch.cuda.is_available() or not hasattr(torch, 'spbqrfactsolve'))
305305
def test_sparse_backward():
306306
torch.manual_seed(0)

0 commit comments

Comments
 (0)