@@ -15,176 +15,172 @@ class QPSolvers(Enum):
1515 CVXPY = 2
1616
1717
18- class QPFunction (Function ):
19- def __init__ (self , eps = 1e-12 , verbose = 0 , notImprovedLim = 3 ,
18+ def QPFunction (eps = 1e-12 , verbose = 0 , notImprovedLim = 3 ,
2019 maxIter = 20 , solver = QPSolvers .PDIPM_BATCHED ,
2120 check_Q_spd = True ):
22- self .eps = eps
23- self .verbose = verbose
24- self .notImprovedLim = notImprovedLim
25- self .maxIter = maxIter
26- self .solver = solver
27- self .check_Q_spd = check_Q_spd
28-
29- def forward (self , Q_ , p_ , G_ , h_ , A_ , b_ ):
30- """Solve a batch of QPs.
31-
32- This function solves a batch of QPs, each optimizing over
33- `nz` variables and having `nineq` inequality constraints
34- and `neq` equality constraints.
35- The optimization problem for each instance in the batch
36- (dropping indexing from the notation) is of the form
37-
38- \hat z = argmin_z 1/2 z^T Q z + p^T z
39- subject to Gz <= h
40- Az = b
41-
42- where Q \in S^{nz,nz},
43- S^{nz,nz} is the set of all positive semi-definite matrices,
44- p \in R^{nz}
45- G \in R^{nineq,nz}
46- h \in R^{nineq}
47- A \in R^{neq,nz}
48- b \in R^{neq}
49-
50- These parameters should all be passed to this function as
51- Variable- or Parameter-wrapped Tensors.
52- (See torch.autograd.Variable and torch.nn.parameter.Parameter)
53-
54- If you want to solve a batch of QPs where `nz`, `nineq` and `neq`
55- are the same, but some of the contents differ across the
56- minibatch, you can pass in tensors in the standard way
57- where the first dimension indicates the batch example.
58- This can be done with some or all of the coefficients.
59-
60- You do not need to add an extra dimension to coefficients
61- that will not change across all of the minibatch examples.
62- This function is able to infer such cases.
63-
64- If you don't want to use any equality or inequality constraints,
65- you can set the appropriate values to:
66-
67- e = Variable(torch.Tensor())
68-
69- Parameters:
70- Q: A (nBatch, nz, nz) or (nz, nz) Tensor.
71- p: A (nBatch, nz) or (nz) Tensor.
72- G: A (nBatch, nineq, nz) or (nineq, nz) Tensor.
73- h: A (nBatch, nineq) or (nineq) Tensor.
74- A: A (nBatch, neq, nz) or (neq, nz) Tensor.
75- b: A (nBatch, neq) or (neq) Tensor.
76-
77- Returns: \hat z: a (nBatch, nz) Tensor.
78- """
79- nBatch = extract_nBatch (Q_ , p_ , G_ , h_ , A_ , b_ )
80- Q , _ = expandParam (Q_ , nBatch , 3 )
81- p , _ = expandParam (p_ , nBatch , 2 )
82- G , _ = expandParam (G_ , nBatch , 3 )
83- h , _ = expandParam (h_ , nBatch , 2 )
84- A , _ = expandParam (A_ , nBatch , 3 )
85- b , _ = expandParam (b_ , nBatch , 2 )
86-
87- if self .check_Q_spd :
88- for i in range (nBatch ):
89- e , _ = torch .eig (Q [i ])
90- if not torch .all (e [:,0 ] > 0 ):
91- raise RuntimeError ('Q is not SPD.' )
92-
93- _ , nineq , nz = G .size ()
94- neq = A .size (1 ) if A .nelement () > 0 else 0
95- assert (neq > 0 or nineq > 0 )
96- self .neq , self .nineq , self .nz = neq , nineq , nz
97-
98- if self .solver == QPSolvers .PDIPM_BATCHED :
99- self .Q_LU , self .S_LU , self .R = pdipm_b .pre_factor_kkt (Q , G , A )
100- zhats , self .nus , self .lams , self .slacks = pdipm_b .forward (
101- Q , p , G , h , A , b , self .Q_LU , self .S_LU , self .R ,
102- self .eps , self .verbose , self .notImprovedLim , self .maxIter )
103- elif self .solver == QPSolvers .CVXPY :
104- vals = torch .Tensor (nBatch ).type_as (Q )
105- zhats = torch .Tensor (nBatch , self .nz ).type_as (Q )
106- lams = torch .Tensor (nBatch , self .nineq ).type_as (Q )
107- nus = torch .Tensor (nBatch , self .neq ).type_as (Q ) \
108- if self .neq > 0 else torch .Tensor ()
109- slacks = torch .Tensor (nBatch , self .nineq ).type_as (Q )
110- for i in range (nBatch ):
111- Ai , bi = (A [i ], b [i ]) if neq > 0 else (None , None )
112- vals [i ], zhati , nui , lami , si = solvers .cvxpy .forward_single_np (
113- * [x .cpu ().numpy () if x is not None else None
114- for x in (Q [i ], p [i ], G [i ], h [i ], Ai , bi )])
115- # if zhati[0] is None:
116- # import IPython, sys; IPython.embed(); sys.exit(-1)
117- zhats [i ] = torch .Tensor (zhati )
118- lams [i ] = torch .Tensor (lami )
119- slacks [i ] = torch .Tensor (si )
120- if neq > 0 :
121- nus [i ] = torch .Tensor (nui )
122-
123- self .vals = vals
124- self .lams = lams
125- self .nus = nus
126- self .slacks = slacks
127- else :
128- assert False
129-
130- self .save_for_backward (zhats , Q_ , p_ , G_ , h_ , A_ , b_ )
131- return zhats
132-
133- def backward (self , dl_dzhat ):
134- zhats , Q , p , G , h , A , b = self .saved_tensors
135- nBatch = extract_nBatch (Q , p , G , h , A , b )
136- Q , Q_e = expandParam (Q , nBatch , 3 )
137- p , p_e = expandParam (p , nBatch , 2 )
138- G , G_e = expandParam (G , nBatch , 3 )
139- h , h_e = expandParam (h , nBatch , 2 )
140- A , A_e = expandParam (A , nBatch , 3 )
141- b , b_e = expandParam (b , nBatch , 2 )
142-
143- # neq, nineq, nz = self.neq, self.nineq, self.nz
144- neq , nineq = self .neq , self .nineq
145-
146-
147- if self .solver == QPSolvers .CVXPY :
148- self .Q_LU , self .S_LU , self .R = pdipm_b .pre_factor_kkt (Q , G , A )
149-
150- # Clamp here to avoid issues coming up when the slacks are too small.
151- # TODO: A better fix would be to get lams and slacks from the
152- # solver that don't have this issue.
153- d = torch .clamp (self .lams , min = 1e-8 ) / torch .clamp (self .slacks , min = 1e-8 )
154-
155- pdipm_b .factor_kkt (self .S_LU , self .R , d )
156- dx , _ , dlam , dnu = pdipm_b .solve_kkt (
157- self .Q_LU , d , G , A , self .S_LU ,
158- dl_dzhat , torch .zeros (nBatch , nineq ).type_as (G ),
159- torch .zeros (nBatch , nineq ).type_as (G ),
160- torch .zeros (nBatch , neq ).type_as (G ) if neq > 0 else torch .Tensor ())
161-
162- dps = dx
163- dGs = bger (dlam , zhats ) + bger (self .lams , dx )
164- if G_e :
165- dGs = dGs .mean (0 )
166- dhs = - dlam
167- if h_e :
168- dhs = dhs .mean (0 )
169- if neq > 0 :
170- dAs = bger (dnu , zhats ) + bger (self .nus , dx )
171- dbs = - dnu
172- if A_e :
173- dAs = dAs .mean (0 )
174- if b_e :
175- dbs = dbs .mean (0 )
176- else :
177- dAs , dbs = None , None
178- dQs = 0.5 * (bger (dx , zhats ) + bger (zhats , dx ))
179- if Q_e :
180- dQs = dQs .mean (0 )
181- if p_e :
182- dps = dps .mean (0 )
183-
184-
185- grads = (dQs , dps , dGs , dhs , dAs , dbs )
186-
187- return grads
21+ class QPFunctionFn (Function ):
22+ @staticmethod
23+ def forward (ctx , Q_ , p_ , G_ , h_ , A_ , b_ ):
24+ """Solve a batch of QPs.
25+
26+ This function solves a batch of QPs, each optimizing over
27+ `nz` variables and having `nineq` inequality constraints
28+ and `neq` equality constraints.
29+ The optimization problem for each instance in the batch
30+ (dropping indexing from the notation) is of the form
31+
32+ \hat z = argmin_z 1/2 z^T Q z + p^T z
33+ subject to Gz <= h
34+ Az = b
35+
36+ where Q \in S^{nz,nz},
37+ S^{nz,nz} is the set of all positive semi-definite matrices,
38+ p \in R^{nz}
39+ G \in R^{nineq,nz}
40+ h \in R^{nineq}
41+ A \in R^{neq,nz}
42+ b \in R^{neq}
43+
44+ These parameters should all be passed to this function as
45+ Variable- or Parameter-wrapped Tensors.
46+ (See torch.autograd.Variable and torch.nn.parameter.Parameter)
47+
48+ If you want to solve a batch of QPs where `nz`, `nineq` and `neq`
49+ are the same, but some of the contents differ across the
50+ minibatch, you can pass in tensors in the standard way
51+ where the first dimension indicates the batch example.
52+ This can be done with some or all of the coefficients.
53+
54+ You do not need to add an extra dimension to coefficients
55+ that will not change across all of the minibatch examples.
56+ This function is able to infer such cases.
57+
58+ If you don't want to use any equality or inequality constraints,
59+ you can set the appropriate values to:
60+
61+ e = Variable(torch.Tensor())
62+
63+ Parameters:
64+ Q: A (nBatch, nz, nz) or (nz, nz) Tensor.
65+ p: A (nBatch, nz) or (nz) Tensor.
66+ G: A (nBatch, nineq, nz) or (nineq, nz) Tensor.
67+ h: A (nBatch, nineq) or (nineq) Tensor.
68+ A: A (nBatch, neq, nz) or (neq, nz) Tensor.
69+ b: A (nBatch, neq) or (neq) Tensor.
70+
71+ Returns: \hat z: a (nBatch, nz) Tensor.
72+ """
73+ nBatch = extract_nBatch (Q_ , p_ , G_ , h_ , A_ , b_ )
74+ Q , _ = expandParam (Q_ , nBatch , 3 )
75+ p , _ = expandParam (p_ , nBatch , 2 )
76+ G , _ = expandParam (G_ , nBatch , 3 )
77+ h , _ = expandParam (h_ , nBatch , 2 )
78+ A , _ = expandParam (A_ , nBatch , 3 )
79+ b , _ = expandParam (b_ , nBatch , 2 )
80+
81+ if check_Q_spd :
82+ for i in range (nBatch ):
83+ e , _ = torch .eig (Q [i ])
84+ if not torch .all (e [:,0 ] > 0 ):
85+ raise RuntimeError ('Q is not SPD.' )
86+
87+ _ , nineq , nz = G .size ()
88+ neq = A .size (1 ) if A .nelement () > 0 else 0
89+ assert (neq > 0 or nineq > 0 )
90+ ctx .neq , ctx .nineq , ctx .nz = neq , nineq , nz
91+
92+ if solver == QPSolvers .PDIPM_BATCHED :
93+ ctx .Q_LU , ctx .S_LU , ctx .R = pdipm_b .pre_factor_kkt (Q , G , A )
94+ zhats , ctx .nus , ctx .lams , ctx .slacks = pdipm_b .forward (
95+ Q , p , G , h , A , b , ctx .Q_LU , ctx .S_LU , ctx .R ,
96+ eps , verbose , notImprovedLim , maxIter )
97+ elif solver == QPSolvers .CVXPY :
98+ vals = torch .Tensor (nBatch ).type_as (Q )
99+ zhats = torch .Tensor (nBatch , ctx .nz ).type_as (Q )
100+ lams = torch .Tensor (nBatch , ctx .nineq ).type_as (Q )
101+ nus = torch .Tensor (nBatch , ctx .neq ).type_as (Q ) \
102+ if ctx .neq > 0 else torch .Tensor ()
103+ slacks = torch .Tensor (nBatch , ctx .nineq ).type_as (Q )
104+ for i in range (nBatch ):
105+ Ai , bi = (A [i ], b [i ]) if neq > 0 else (None , None )
106+ vals [i ], zhati , nui , lami , si = solvers .cvxpy .forward_single_np (
107+ * [x .cpu ().numpy () if x is not None else None
108+ for x in (Q [i ], p [i ], G [i ], h [i ], Ai , bi )])
109+ # if zhati[0] is None:
110+ # import IPython, sys; IPython.embed(); sys.exit(-1)
111+ zhats [i ] = torch .Tensor (zhati )
112+ lams [i ] = torch .Tensor (lami )
113+ slacks [i ] = torch .Tensor (si )
114+ if neq > 0 :
115+ nus [i ] = torch .Tensor (nui )
116+
117+ ctx .vals = vals
118+ ctx .lams = lams
119+ ctx .nus = nus
120+ ctx .slacks = slacks
121+ else :
122+ assert False
123+
124+ ctx .save_for_backward (zhats , Q_ , p_ , G_ , h_ , A_ , b_ )
125+ return zhats
126+
127+ @staticmethod
128+ def backward (ctx , dl_dzhat ):
129+ zhats , Q , p , G , h , A , b = ctx .saved_tensors
130+ nBatch = extract_nBatch (Q , p , G , h , A , b )
131+ Q , Q_e = expandParam (Q , nBatch , 3 )
132+ p , p_e = expandParam (p , nBatch , 2 )
133+ G , G_e = expandParam (G , nBatch , 3 )
134+ h , h_e = expandParam (h , nBatch , 2 )
135+ A , A_e = expandParam (A , nBatch , 3 )
136+ b , b_e = expandParam (b , nBatch , 2 )
137+
138+ # neq, nineq, nz = ctx.neq, ctx.nineq, ctx.nz
139+ neq , nineq = ctx .neq , ctx .nineq
140+
141+
142+ if solver == QPSolvers .CVXPY :
143+ ctx .Q_LU , ctx .S_LU , ctx .R = pdipm_b .pre_factor_kkt (Q , G , A )
144+
145+ # Clamp here to avoid issues coming up when the slacks are too small.
146+ # TODO: A better fix would be to get lams and slacks from the
147+ # solver that don't have this issue.
148+ d = torch .clamp (ctx .lams , min = 1e-8 ) / torch .clamp (ctx .slacks , min = 1e-8 )
149+
150+ pdipm_b .factor_kkt (ctx .S_LU , ctx .R , d )
151+ dx , _ , dlam , dnu = pdipm_b .solve_kkt (
152+ ctx .Q_LU , d , G , A , ctx .S_LU ,
153+ dl_dzhat , torch .zeros (nBatch , nineq ).type_as (G ),
154+ torch .zeros (nBatch , nineq ).type_as (G ),
155+ torch .zeros (nBatch , neq ).type_as (G ) if neq > 0 else torch .Tensor ())
156+
157+ dps = dx
158+ dGs = bger (dlam , zhats ) + bger (ctx .lams , dx )
159+ if G_e :
160+ dGs = dGs .mean (0 )
161+ dhs = - dlam
162+ if h_e :
163+ dhs = dhs .mean (0 )
164+ if neq > 0 :
165+ dAs = bger (dnu , zhats ) + bger (ctx .nus , dx )
166+ dbs = - dnu
167+ if A_e :
168+ dAs = dAs .mean (0 )
169+ if b_e :
170+ dbs = dbs .mean (0 )
171+ else :
172+ dAs , dbs = None , None
173+ dQs = 0.5 * (bger (dx , zhats ) + bger (zhats , dx ))
174+ if Q_e :
175+ dQs = dQs .mean (0 )
176+ if p_e :
177+ dps = dps .mean (0 )
178+
179+
180+ grads = (dQs , dps , dGs , dhs , dAs , dbs )
181+
182+ return grads
183+ return QPFunctionFn .apply
188184
189185
190186class SpQPFunction (Function ):
0 commit comments