Skip to content

Commit d2d9f3e

Browse files
author
Zeyi Wen
authored
Merge pull request #5 from shijiashuai/master
added probability training and prediction
2 parents 829459a + 400892e commit d2d9f3e

File tree

3 files changed

+281
-31
lines changed

3 files changed

+281
-31
lines changed

mascot/svmModel.cu

Lines changed: 256 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <helper_cuda.h>
1212
#include <cuda_runtime_api.h>
1313
#include "trainingFunction.h"
14+
1415
/**
 * Maps an ordered class pair (i, j), with i < j, to the flat index of its
 * one-vs-one binary classifier among the nrClass*(nrClass-1)/2 models.
 * The enumeration is row-major over the strict upper triangle:
 * (0,1), (0,2), ..., (0,n-1), (1,2), ...
 */
unsigned int svmModel::getK(int i, int j) const {
    // Offset of row i: sum of the i preceding (shrinking) row lengths,
    // i.e. sum_{r=0}^{i-1} (nrClass - 1 - r) = ((nrClass - 1) + (nrClass - i)) * i / 2.
    unsigned int rowOffset = ((nrClass - 1) + (nrClass - i)) * i / 2;
    return rowOffset + (j - i - 1);
}
@@ -24,6 +25,7 @@ void svmModel::fit(const svmProblem &problem, const SVMParam &param) {
2425
probB.clear();
2526
supportVectors.clear();
2627
label.clear();
28+
probability = false;
2729

2830
coef.resize(cnr2);
2931
rho.resize(cnr2);
@@ -33,16 +35,140 @@ void svmModel::fit(const svmProblem &problem, const SVMParam &param) {
3335

3436
this->param = param;
3537
label = problem.label;
38+
int k = 0;
3639
for (int i = 0; i < nrClass; ++i) {
3740
for (int j = i + 1; j < nrClass; ++j) {
3841
svmProblem subProblem = problem.getSubProblem(i, j);
3942
printf("training classifier with label %d and %d\n", i, j);
43+
if (param.probability) {
44+
SVMParam probParam = param;
45+
probParam.probability = 0;
46+
probParam.C = 1.0;
47+
svmModel model;
48+
model.fit(subProblem, probParam);
49+
vector<float_point *> decValues;
50+
//todo predict with cross validation
51+
model.predictValues(subProblem.v_vSamples, decValues);
52+
//binary model has only one sub-model
53+
sigmoidTrain(decValues.front(), subProblem.getNumOfSamples(), subProblem.v_nLabels, probA[k], probB[k]);
54+
probability = true;
55+
}
4056
svm_model binaryModel = trainBinarySVM(subProblem, param);
41-
addBinaryModel(subProblem, binaryModel,i,j);
57+
addBinaryModel(subProblem, binaryModel, i, j);
58+
k++;
4259
}
4360
}
4461
}
4562

63+
/**
 * Platt scaling: fits sigmoid parameters (A, B) so that
 * P(y = 1 | f) ~= 1 / (1 + exp(A * f + B)) for a decision value f.
 * Uses Newton's method with a backtracking (Armijo) line search on the
 * regularized maximum-likelihood objective; ported from LIBSVM's
 * sigmoid_train.
 *
 * @param decValues decision values f_i for the l training samples
 * @param l         number of samples
 * @param labels    binary labels; labels[i] > 0 is the positive class
 * @param A         output sigmoid slope
 * @param B         output sigmoid intercept
 */
void svmModel::sigmoidTrain(const float_point *decValues, const int l, const vector<int> &labels, float_point &A,
                            float_point &B) {
    double prior1 = 0, prior0 = 0;

    for (int i = 0; i < l; i++)
        if (labels[i] > 0)
            prior1 += 1;
        else
            prior0 += 1;

    int max_iter = 100;      // Maximal number of Newton iterations
    double min_step = 1e-10; // Minimal step taken in line search
    double sigma = 1e-12;    // Ridge term: numerically strict PD of Hessian
    double eps = 1e-5;       // Gradient tolerance for convergence
    // Regularized targets (Platt 1999): avoids exact 0/1 targets that would
    // drive A and B to infinity.
    double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
    double loTarget = 1 / (prior0 + 2.0);
    // std::vector instead of the original malloc/free pair: no manual
    // deallocation, identical numerics.
    vector<double> t(l);
    double fApB, p, q, h11, h22, h21, g1, g2, det, dA, dB, gd, stepsize;
    double newA, newB, newf, d1, d2;
    int iter;

    // Initial point and initial objective value.
    A = 0.0;
    B = log((prior0 + 1.0) / (prior1 + 1.0));
    double fval = 0.0;

    for (int i = 0; i < l; i++) {
        if (labels[i] > 0)
            t[i] = hiTarget;
        else
            t[i] = loTarget;
        fApB = decValues[i] * A + B;
        // Two algebraically equal forms, chosen by sign to avoid overflow
        // in exp().
        if (fApB >= 0)
            fval += t[i] * fApB + log(1 + exp(-fApB));
        else
            fval += (t[i] - 1) * fApB + log(1 + exp(fApB));
    }
    for (iter = 0; iter < max_iter; iter++) {
        // Update gradient and Hessian (use H' = H + sigma I).
        h11 = sigma; // numerically ensures strict PD
        h22 = sigma;
        h21 = 0.0;
        g1 = 0.0;
        g2 = 0.0;
        for (int i = 0; i < l; i++) {
            fApB = decValues[i] * A + B;
            if (fApB >= 0) {
                p = exp(-fApB) / (1.0 + exp(-fApB));
                q = 1.0 / (1.0 + exp(-fApB));
            } else {
                p = 1.0 / (1.0 + exp(fApB));
                q = exp(fApB) / (1.0 + exp(fApB));
            }
            d2 = p * q;
            h11 += decValues[i] * decValues[i] * d2;
            h22 += d2;
            h21 += decValues[i] * d2;
            d1 = t[i] - p;
            g1 += decValues[i] * d1;
            g2 += d1;
        }

        // Stopping criterion: gradient close to zero.
        if (fabs(g1) < eps && fabs(g2) < eps)
            break;

        // Newton direction: -inv(H') * g, 2x2 solve via Cramer's rule.
        det = h11 * h22 - h21 * h21;
        dA = -(h22 * g1 - h21 * g2) / det;
        dB = -(-h21 * g1 + h11 * g2) / det;
        gd = g1 * dA + g2 * dB;

        stepsize = 1; // Line search
        while (stepsize >= min_step) {
            newA = A + stepsize * dA;
            newB = B + stepsize * dB;

            // New function value at the trial point.
            newf = 0.0;
            for (int i = 0; i < l; i++) {
                fApB = decValues[i] * newA + newB;
                if (fApB >= 0)
                    newf += t[i] * fApB + log(1 + exp(-fApB));
                else
                    newf += (t[i] - 1) * fApB + log(1 + exp(fApB));
            }
            // Check sufficient decrease (Armijo condition).
            if (newf < fval + 0.0001 * stepsize * gd) {
                A = newA;
                B = newB;
                fval = newf;
                break;
            } else
                stepsize = stepsize / 2.0;
        }

        if (stepsize < min_step) {
            printf("Line search fails in two-class probability estimates\n");
            break;
        }
    }

    if (iter >= max_iter)
        printf(
                "Reaching maximal iterations in two-class probability estimates\n");
}
171+
46172
void svmModel::addBinaryModel(const svmProblem &problem, const svm_model &bModel, int i, int j) {
47173
unsigned int k = getK(i, j);
48174
for (int l = 0; l < bModel.nSV[0] + bModel.nSV[1]; ++l) {
@@ -55,41 +181,146 @@ void svmModel::addBinaryModel(const svmProblem &problem, const svm_model &bModel
55181
rho[k] = bModel.rho[0];
56182
}
57183

/**
 * Computes the decision values of every one-vs-one binary classifier for
 * the given samples. On return, decisionValues holds cnr2 arrays; entry k
 * stores v_vSamples.size() decision values from binary model k.
 *
 * NOTE(review): the arrays pushed into decisionValues come from
 * predictLabels() (see its TODO about returning a raw pointer) — ownership
 * passes to the caller; confirm how they should be released.
 */
void
svmModel::predictValues(const vector<vector<float_point> > &v_vSamples, vector<float_point *> &decisionValues) const {
    decisionValues.clear();
    const int numSamples = (int) v_vSamples.size();
    for (int k = 0; k < cnr2; ++k) {
        // Scratch buffer: kernel values between every sample and every
        // support vector of binary model k.
        float_point *kernelValues = new float_point[numSamples * supportVectors[k].size()];
        computeKernelValuesOnFly(v_vSamples, supportVectors[k], kernelValues);
        float_point *decValuesOfModel = predictLabels(kernelValues, numSamples, k);//TODO not return local pointer in function
        decisionValues.push_back(decValuesOfModel);
        delete[] kernelValues;
    }
}
68195

/**
 * Predicts a class label for each sample.
 *
 * probability == false: majority voting over the cnr2 one-vs-one
 * classifiers (sign of each decision value casts one vote).
 * probability == true : pairwise-coupled probability estimates (the model
 * must have been fitted with probability training; asserted below) and
 * the argmax-probability class is returned.
 *
 * NOTE(review): the per-classifier arrays returned via predictValues()
 * are never released here — looks like a leak; confirm the ownership of
 * predictLabels()'s result before adding a free/delete.
 */
vector<int> svmModel::predict(const vector<vector<float_point> > &v_vSamples, bool probability) const {
    vector<float_point *> decisionValues;
    predictValues(v_vSamples, decisionValues);
    vector<int> labels;
    if (!probability) {
        for (int l = 0; l < v_vSamples.size(); ++l) {
            // One vote per binary classifier; most votes wins.
            vector<int> votes(nrClass, 0);
            int k = 0;
            for (int i = 0; i < nrClass; ++i) {
                for (int j = i + 1; j < nrClass; ++j) {
                    int winner = (decisionValues[k++][l] > 0) ? i : j;
                    votes[winner]++;
                }
            }
            int maxVoteClass = 0;
            for (int i = 1; i < nrClass; ++i) {
                if (votes[i] > votes[maxVoteClass])
                    maxVoteClass = i;
            }
            labels.push_back(this->label[maxVoteClass]);
        }
    } else {
        printf("predict with probability\n");
        assert(this->probability);
        vector<vector<float_point> > prob = predictProbability(v_vSamples);
        // todo select max using GPU
        for (int i = 0; i < v_vSamples.size(); ++i) {
            // Argmax over the per-class probabilities of sample i.
            int maxProbClass = 0;
            for (int j = 1; j < nrClass; ++j) {
                if (prob[i][j] > prob[i][maxProbClass])
                    maxProbClass = j;
            }
            labels.push_back(this->label[maxProbClass]);
        }
    }
    return labels;
}
92235

236+
/**
 * Evaluates the fitted sigmoid 1 / (1 + exp(A*f + B)) at decision value f.
 * The branch on the sign of the exponent keeps exp() from overflowing;
 * 1-p is used later, so this form also avoids catastrophic cancellation.
 */
float_point svmModel::sigmoid_predict(float_point decValue, float_point A, float_point B) const {
    const double fApB = decValue * A + B;
    return (fApB >= 0) ? exp(-fApB) / (1.0 + exp(-fApB))
                       : 1.0 / (1 + exp(fApB));
}
244+
245+
/**
 * Pairwise coupling: combines the pairwise probability estimates
 * r[i][j] ~= P(y = i | y in {i, j}, x) into one multi-class distribution p
 * by the fixed-point iteration of Wu, Lin & Weng (JMLR 2004), method 2 —
 * the same algorithm as LIBSVM's multiclass_probability.
 *
 * @param r nrClass x nrClass pairwise probabilities (r[j][i] == 1 - r[i][j])
 * @param p output: probability of each of the nrClass classes (resized here)
 */
void svmModel::multiclass_probability(const vector<vector<float_point> > &r, vector<float_point> &p) const {
    int t, j;
    int iter = 0, max_iter = max(100, nrClass);
    // std::vector instead of the original per-call malloc'd Q/Qp buffers:
    // no manual frees, no leak on early exit; numerics unchanged.
    vector<vector<double> > Q(nrClass, vector<double>(nrClass));
    vector<double> Qp(nrClass);
    double pQp, eps = 0.005 / nrClass;

    p.resize(nrClass); // tolerate an unsized output vector (p[t] is written below)
    for (t = 0; t < nrClass; t++) {
        p[t] = 1.0 / nrClass; // Valid if k = 1
        Q[t][t] = 0;
        for (j = 0; j < t; j++) {
            Q[t][t] += r[j][t] * r[j][t];
            Q[t][j] = Q[j][t];
        }
        for (j = t + 1; j < nrClass; j++) {
            Q[t][t] += r[j][t] * r[j][t];
            Q[t][j] = -r[j][t] * r[t][j];
        }
    }
    for (iter = 0; iter < max_iter; iter++) {
        // Stopping condition: recalculate Qp, pQp for numerical accuracy.
        pQp = 0;
        for (t = 0; t < nrClass; t++) {
            Qp[t] = 0;
            for (j = 0; j < nrClass; j++)
                Qp[t] += Q[t][j] * p[j];
            pQp += p[t] * Qp[t];
        }
        double max_error = 0;
        for (t = 0; t < nrClass; t++) {
            double error = fabs(Qp[t] - pQp);
            if (error > max_error)
                max_error = error;
        }
        if (max_error < eps)
            break;

        // One coordinate update per class, renormalizing p in place.
        for (t = 0; t < nrClass; t++) {
            double diff = (-Qp[t] + pQp) / Q[t][t];
            p[t] += diff;
            pQp = (pQp + diff * (diff * Q[t][t] + 2 * Qp[t])) / (1 + diff)
                  / (1 + diff);
            for (j = 0; j < nrClass; j++) {
                Qp[j] = (Qp[j] + diff * Q[t][j]) / (1 + diff);
                p[j] /= (1 + diff);
            }
        }
    }
    if (iter >= max_iter)
        printf("Exceeds max_iter in multiclass_prob\n");
}
301+
302+
/**
 * Estimates per-class probabilities for every sample: each pairwise
 * sigmoid output is clipped into [min_prob, 1 - min_prob] and the
 * resulting pairwise matrix is coupled into one distribution per sample
 * by multiclass_probability().
 *
 * NOTE(review): the arrays in decValues are never freed here — see the
 * ownership TODO on predictValues()/predictLabels().
 */
vector<vector<float_point> > svmModel::predictProbability(const vector<vector<float_point> > &v_vSamples) const {
    vector<vector<float_point> > result;
    vector<float_point *> decValues;
    predictValues(v_vSamples, decValues);
    const double min_prob = 1e-7; // keeps the coupling system well-conditioned
    for (int l = 0; l < v_vSamples.size(); ++l) {
        // Pairwise probability matrix for sample l.
        vector<vector<float_point> > r(nrClass, vector<float_point>(nrClass));
        int k = 0;
        for (int i = 0; i < nrClass; i++)
            for (int j = i + 1; j < nrClass; j++) {
                double pairProb = sigmoid_predict(decValues[k][l], probA[k], probB[k]);
                r[i][j] = min(max(pairProb, min_prob), 1 - min_prob);
                r[j][i] = 1 - r[i][j];
                k++;
            }
        vector<float_point> p(nrClass);
        multiclass_probability(r, p);
        result.push_back(p);
    }
    return result;
}
323+
93324
void
94325
svmModel::computeKernelValuesOnFly(const vector<vector<float_point> > &samples,
95326
const vector<vector<float_point> > &supportVectors,

mascot/svmModel.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "../svm-shared/gpu_global_utility.h"
1010
#include "svmProblem.h"
1111
#include "svmParam.h"
12+
1213
using std::vector;
1314

1415
class svmModel {
@@ -22,24 +23,37 @@ class svmModel {
2223
vector<float_point> probA;
2324
vector<float_point> probB;
2425
vector<int> label;
26+
bool probability;
2527

2628
unsigned int inline getK(int i, int j) const;
2729

28-
float_point* predictLabels(const float_point *kernelValues, int, int) const;
30+
float_point *predictLabels(const float_point *kernelValues, int, int) const;
2931

30-
float_point* ComputeClassLabel(int nNumofTestSamples,
32+
float_point *ComputeClassLabel(int nNumofTestSamples,
3133
float_point *pfDevSVYiAlphaHessian, const int &nNumofSVs,
3234
float_point fBias, float_point *pfFinalResult) const;
3335

3436
void computeKernelValuesOnFly(const vector<vector<float_point> > &samples,
3537
const vector<vector<float_point> > &supportVectors, float_point *kernelValues) const;
3638

3739
void addBinaryModel(const svmProblem &, const svm_model &, int i, int j);
40+
41+
float_point sigmoid_predict(float_point decValue, float_point A, float_point B) const;
42+
43+
void multiclass_probability(const vector<vector<float_point> > &, vector<float_point> &) const;
44+
45+
void
46+
sigmoidTrain(const float_point *decValues, const int, const vector<int> &labels, float_point &A, float_point &B);
47+
3848
public:
3949

40-
void fit(const svmProblem& problem, const SVMParam &param);
41-
vector<int> predict(const vector<vector<float_point> > &) const;
42-
vector<float_point* > predictValues(const vector<vector<float_point> >&) const;
50+
void fit(const svmProblem &problem, const SVMParam &param);
51+
52+
vector<int> predict(const vector<vector<float_point> > &, bool probability=false) const;
53+
54+
vector<vector<float_point> > predictProbability(const vector<vector<float_point> > &) const;
55+
56+
void predictValues(const vector<vector<float_point> > &, vector<float_point *> &) const;
4357
};
4458

4559

mascot/trainingFunction.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,17 @@ svmModel trainSVM(SVMParam &param, string strTrainingFileName, int nNumofFeature
4747
rawDataRead.ReadFromFile(strTrainingFileName, nNumofFeature, v_v_DocVector, v_nLabel);
4848
svmProblem problem(v_v_DocVector, v_nLabel);
4949
svmModel model;
50+
param.probability = 1;//train with probability
5051
model.fit(problem, param);
51-
vector<int> predictLabels = model.predict(v_v_DocVector);
52+
vector<int> predictLabels = model.predict(v_v_DocVector, true);
5253
int numOfCorrect = 0;
5354
for (int i = 0; i < v_v_DocVector.size(); ++i) {
5455
if (predictLabels[i] == v_nLabel[i])
5556
numOfCorrect++;
57+
// for (int j = 0; j < problem.getNumOfClasses(); ++j) {
58+
// printf("%.2f,",prob[i][j]);
59+
// }
60+
// printf("\n");
5661
}
5762
printf("training accuracy = %.2f%%(%d/%d)\n", numOfCorrect / (float) v_v_DocVector.size()*100, numOfCorrect,
5863
(int) v_v_DocVector.size());

0 commit comments

Comments
 (0)