Skip to content

Commit 8c37dc3

Browse files
committed
add try-catch in main() #31
1 parent 1a89920 commit 8c37dc3

File tree

2 files changed

+126
-100
lines changed

2 files changed

+126
-100
lines changed

src/thundersvm/thundersvm-predict.cpp

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,66 @@
1414
using std::fstream;
1515

1616
int main(int argc, char **argv) {
17-
CMDParser parser;
18-
parser.parse_command_line(argc, argv);
19-
fstream file;
20-
file.open(parser.svmpredict_model_file_name, fstream::in);
21-
std::cout << parser.svmpredict_model_file_name << "\n";
22-
string feature, svm_type;
23-
file >> feature >> svm_type;
24-
std::cout << feature << "; " << svm_type << "\n";
25-
CHECK_EQ(feature, "svm_type");
26-
SvmModel *model = nullptr;
27-
Metric *metric = nullptr;
28-
if (svm_type == "c_svc") {
29-
model = new SVC();
30-
metric = new Accuracy();
31-
} else if (svm_type == "nu_svc") {
32-
model = new NuSVC();
33-
metric = new Accuracy();
34-
} else if (svm_type == "one_class") {
35-
model = new OneClassSVC();
36-
//todo determine a metric
37-
} else if (svm_type == "epsilon_svr") {
38-
model = new SVR();
39-
metric = new MSE();
40-
} else if (svm_type == "nu_svr") {
41-
model = new NuSVR();
42-
metric = new MSE();
43-
}
17+
try {
18+
CMDParser parser;
19+
parser.parse_command_line(argc, argv);
20+
fstream file;
21+
file.open(parser.svmpredict_model_file_name, fstream::in);
22+
std::cout << parser.svmpredict_model_file_name << "\n";
23+
string feature, svm_type;
24+
file >> feature >> svm_type;
25+
std::cout << feature << "; " << svm_type << "\n";
26+
CHECK_EQ(feature, "svm_type");
27+
std::shared_ptr<SvmModel> model;
28+
std::shared_ptr<Metric> metric;
29+
if (svm_type == "c_svc") {
30+
model = new SVC();
31+
metric = new Accuracy();
32+
} else if (svm_type == "nu_svc") {
33+
model = new NuSVC();
34+
metric = new Accuracy();
35+
} else if (svm_type == "one_class") {
36+
model = new OneClassSVC();
37+
//todo determine a metric
38+
} else if (svm_type == "epsilon_svr") {
39+
model = new SVR();
40+
metric = new MSE();
41+
} else if (svm_type == "nu_svr") {
42+
model = new NuSVR();
43+
metric = new MSE();
44+
}
4445

4546
#ifdef USE_CUDA
46-
CUDA_CHECK(cudaSetDevice(parser.gpu_id));
47+
CUDA_CHECK(cudaSetDevice(parser.gpu_id));
4748
#endif
4849

49-
model->load_from_file(parser.svmpredict_model_file_name);
50-
file.close();
51-
file.open(parser.svmpredict_output_file);
52-
DataSet predict_dataset;
53-
predict_dataset.load_from_file(parser.svmpredict_input_file);
50+
model->load_from_file(parser.svmpredict_model_file_name);
51+
file.close();
52+
file.open(parser.svmpredict_output_file);
53+
DataSet predict_dataset;
54+
predict_dataset.load_from_file(parser.svmpredict_input_file);
5455

55-
vector<float_type> predict_y;
56-
predict_y = model->predict(predict_dataset.instances(), 10000);
57-
for (int i = 0; i < predict_y.size(); ++i) {
58-
file << predict_y[i] << std::endl;
59-
}
60-
file.close();
56+
vector<float_type> predict_y;
57+
predict_y = model->predict(predict_dataset.instances(), 10000);
58+
for (int i = 0; i < predict_y.size(); ++i) {
59+
file << predict_y[i] << std::endl;
60+
}
61+
file.close();
6162

62-
if (metric) {
63-
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, predict_dataset.y());
63+
if (metric) {
64+
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, predict_dataset.y());
65+
}
66+
}
67+
catch (std::bad_alloc &) {
68+
LOG(FATAL) << "out of host memory";
69+
exit(EXIT_FAILURE);
70+
}
71+
catch (std::exception const &x) {
72+
LOG(FATAL) << x.what();
73+
exit(EXIT_FAILURE);
74+
}
75+
catch (...) {
76+
LOG(FATAL) << "unknown error";
77+
exit(EXIT_FAILURE);
6478
}
6579
}

src/thundersvm/thundersvm-train.cpp

Lines changed: 71 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,77 +14,89 @@
1414

1515

1616
int main(int argc, char **argv) {
17-
CMDParser parser;
18-
parser.parse_command_line(argc, argv);
19-
DataSet train_dataset;
20-
train_dataset.load_from_file(parser.svmtrain_input_file_name);
21-
SvmModel *model = nullptr;
22-
switch (parser.param_cmd.svm_type) {
23-
case SvmParam::C_SVC:
24-
model = new SVC();
25-
break;
26-
case SvmParam::NU_SVC:
27-
model = new NuSVC();
28-
break;
29-
case SvmParam::ONE_CLASS:
30-
model = new OneClassSVC();
31-
break;
32-
case SvmParam::EPSILON_SVR:
33-
model = new SVR();
34-
break;
35-
case SvmParam::NU_SVR:
36-
model = new NuSVR();
37-
break;
38-
}
17+
try {
18+
CMDParser parser;
19+
parser.parse_command_line(argc, argv);
20+
DataSet train_dataset;
21+
train_dataset.load_from_file(parser.svmtrain_input_file_name);
22+
std::shared_ptr<SvmModel> model;
23+
switch (parser.param_cmd.svm_type) {
24+
case SvmParam::C_SVC:
25+
model = new SVC();
26+
break;
27+
case SvmParam::NU_SVC:
28+
model = new NuSVC();
29+
break;
30+
case SvmParam::ONE_CLASS:
31+
model = new OneClassSVC();
32+
break;
33+
case SvmParam::EPSILON_SVR:
34+
model = new SVR();
35+
break;
36+
case SvmParam::NU_SVR:
37+
model = new NuSVR();
38+
break;
39+
}
3940

40-
//todo add this to check_parameter method
41-
if (parser.param_cmd.svm_type == SvmParam::NU_SVC) {
42-
train_dataset.group_classes();
43-
for (int i = 0; i < train_dataset.n_classes(); ++i) {
44-
int n1 = train_dataset.count()[i];
45-
for (int j = i + 1; j < train_dataset.n_classes(); ++j) {
46-
int n2 = train_dataset.count()[j];
47-
if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) {
48-
printf("specified nu is infeasible\n");
49-
return 1;
41+
//todo add this to check_parameter method
42+
if (parser.param_cmd.svm_type == SvmParam::NU_SVC) {
43+
train_dataset.group_classes();
44+
for (int i = 0; i < train_dataset.n_classes(); ++i) {
45+
int n1 = train_dataset.count()[i];
46+
for (int j = i + 1; j < train_dataset.n_classes(); ++j) {
47+
int n2 = train_dataset.count()[j];
48+
if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) {
49+
printf("specified nu is infeasible\n");
50+
return 1;
51+
}
5052
}
5153
}
5254
}
53-
}
5455

5556
#ifdef USE_CUDA
56-
CUDA_CHECK(cudaSetDevice(parser.gpu_id));
57+
CUDA_CHECK(cudaSetDevice(parser.gpu_id));
5758
#endif
5859

59-
vector<float_type> predict_y;
60-
if (parser.do_cross_validation) {
61-
predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold);
62-
} else {
63-
model->train(train_dataset, parser.param_cmd);
64-
model->save_to_file(parser.model_file_name);
65-
predict_y = model->predict(train_dataset.instances(), 10000);
66-
}
67-
68-
//perform svm testing
69-
Metric *metric = nullptr;
70-
switch (parser.param_cmd.svm_type) {
71-
case SvmParam::C_SVC:
72-
case SvmParam::NU_SVC: {
73-
metric = new Accuracy();
74-
break;
60+
vector<float_type> predict_y;
61+
if (parser.do_cross_validation) {
62+
predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold);
63+
} else {
64+
model->train(train_dataset, parser.param_cmd);
65+
model->save_to_file(parser.model_file_name);
66+
predict_y = model->predict(train_dataset.instances(), 10000);
7567
}
76-
case SvmParam::EPSILON_SVR:
77-
case SvmParam::NU_SVR: {
78-
metric = new MSE();
79-
break;
68+
69+
//perform svm testing
70+
std::shared_ptr<Metric> metric;
71+
switch (parser.param_cmd.svm_type) {
72+
case SvmParam::C_SVC:
73+
case SvmParam::NU_SVC: {
74+
metric = new Accuracy();
75+
break;
76+
}
77+
case SvmParam::EPSILON_SVR:
78+
case SvmParam::NU_SVR: {
79+
metric = new MSE();
80+
break;
81+
}
82+
case SvmParam::ONE_CLASS: {
83+
}
8084
}
81-
case SvmParam::ONE_CLASS: {
85+
if (metric) {
86+
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y());
8287
}
8388
}
84-
if (metric) {
85-
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y());
89+
catch (std::bad_alloc &) {
90+
LOG(FATAL) << "out of host memory";
91+
exit(EXIT_FAILURE);
92+
}
93+
catch (std::exception const &x) {
94+
LOG(FATAL) << x.what();
95+
exit(EXIT_FAILURE);
96+
}
97+
catch (...) {
98+
LOG(FATAL) << "unknown error";
99+
exit(EXIT_FAILURE);
86100
}
87-
delete model;
88-
delete metric;
89101
}
90102

0 commit comments

Comments
 (0)