|
14 | 14 |
|
15 | 15 |
|
16 | 16 | int main(int argc, char **argv) { |
17 | | - CMDParser parser; |
18 | | - parser.parse_command_line(argc, argv); |
19 | | - DataSet train_dataset; |
20 | | - train_dataset.load_from_file(parser.svmtrain_input_file_name); |
21 | | - SvmModel *model = nullptr; |
22 | | - switch (parser.param_cmd.svm_type) { |
23 | | - case SvmParam::C_SVC: |
24 | | - model = new SVC(); |
25 | | - break; |
26 | | - case SvmParam::NU_SVC: |
27 | | - model = new NuSVC(); |
28 | | - break; |
29 | | - case SvmParam::ONE_CLASS: |
30 | | - model = new OneClassSVC(); |
31 | | - break; |
32 | | - case SvmParam::EPSILON_SVR: |
33 | | - model = new SVR(); |
34 | | - break; |
35 | | - case SvmParam::NU_SVR: |
36 | | - model = new NuSVR(); |
37 | | - break; |
38 | | - } |
| 17 | + try { |
| 18 | + CMDParser parser; |
| 19 | + parser.parse_command_line(argc, argv); |
| 20 | + DataSet train_dataset; |
| 21 | + train_dataset.load_from_file(parser.svmtrain_input_file_name); |
| 22 | + std::shared_ptr<SvmModel> model; |
| 23 | + switch (parser.param_cmd.svm_type) { |
| 24 | + case SvmParam::C_SVC: |
| 25 | + model = new SVC(); |
| 26 | + break; |
| 27 | + case SvmParam::NU_SVC: |
| 28 | + model = new NuSVC(); |
| 29 | + break; |
| 30 | + case SvmParam::ONE_CLASS: |
| 31 | + model = new OneClassSVC(); |
| 32 | + break; |
| 33 | + case SvmParam::EPSILON_SVR: |
| 34 | + model = new SVR(); |
| 35 | + break; |
| 36 | + case SvmParam::NU_SVR: |
| 37 | + model = new NuSVR(); |
| 38 | + break; |
| 39 | + } |
39 | 40 |
|
40 | | - //todo add this to check_parameter method |
41 | | - if (parser.param_cmd.svm_type == SvmParam::NU_SVC) { |
42 | | - train_dataset.group_classes(); |
43 | | - for (int i = 0; i < train_dataset.n_classes(); ++i) { |
44 | | - int n1 = train_dataset.count()[i]; |
45 | | - for (int j = i + 1; j < train_dataset.n_classes(); ++j) { |
46 | | - int n2 = train_dataset.count()[j]; |
47 | | - if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) { |
48 | | - printf("specified nu is infeasible\n"); |
49 | | - return 1; |
| 41 | + //todo add this to check_parameter method |
| 42 | + if (parser.param_cmd.svm_type == SvmParam::NU_SVC) { |
| 43 | + train_dataset.group_classes(); |
| 44 | + for (int i = 0; i < train_dataset.n_classes(); ++i) { |
| 45 | + int n1 = train_dataset.count()[i]; |
| 46 | + for (int j = i + 1; j < train_dataset.n_classes(); ++j) { |
| 47 | + int n2 = train_dataset.count()[j]; |
| 48 | + if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) { |
| 49 | + printf("specified nu is infeasible\n"); |
| 50 | + return 1; |
| 51 | + } |
50 | 52 | } |
51 | 53 | } |
52 | 54 | } |
53 | | - } |
54 | 55 |
|
55 | 56 | #ifdef USE_CUDA |
56 | | - CUDA_CHECK(cudaSetDevice(parser.gpu_id)); |
| 57 | + CUDA_CHECK(cudaSetDevice(parser.gpu_id)); |
57 | 58 | #endif |
58 | 59 |
|
59 | | - vector<float_type> predict_y; |
60 | | - if (parser.do_cross_validation) { |
61 | | - predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold); |
62 | | - } else { |
63 | | - model->train(train_dataset, parser.param_cmd); |
64 | | - model->save_to_file(parser.model_file_name); |
65 | | - predict_y = model->predict(train_dataset.instances(), 10000); |
66 | | - } |
67 | | - |
68 | | - //perform svm testing |
69 | | - Metric *metric = nullptr; |
70 | | - switch (parser.param_cmd.svm_type) { |
71 | | - case SvmParam::C_SVC: |
72 | | - case SvmParam::NU_SVC: { |
73 | | - metric = new Accuracy(); |
74 | | - break; |
| 60 | + vector<float_type> predict_y; |
| 61 | + if (parser.do_cross_validation) { |
| 62 | + predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold); |
| 63 | + } else { |
| 64 | + model->train(train_dataset, parser.param_cmd); |
| 65 | + model->save_to_file(parser.model_file_name); |
| 66 | + predict_y = model->predict(train_dataset.instances(), 10000); |
75 | 67 | } |
76 | | - case SvmParam::EPSILON_SVR: |
77 | | - case SvmParam::NU_SVR: { |
78 | | - metric = new MSE(); |
79 | | - break; |
| 68 | + |
| 69 | + //perform svm testing |
| 70 | + std::shared_ptr<Metric> metric; |
| 71 | + switch (parser.param_cmd.svm_type) { |
| 72 | + case SvmParam::C_SVC: |
| 73 | + case SvmParam::NU_SVC: { |
| 74 | + metric = new Accuracy(); |
| 75 | + break; |
| 76 | + } |
| 77 | + case SvmParam::EPSILON_SVR: |
| 78 | + case SvmParam::NU_SVR: { |
| 79 | + metric = new MSE(); |
| 80 | + break; |
| 81 | + } |
| 82 | + case SvmParam::ONE_CLASS: { |
| 83 | + } |
80 | 84 | } |
81 | | - case SvmParam::ONE_CLASS: { |
| 85 | + if (metric) { |
| 86 | + LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y()); |
82 | 87 | } |
83 | 88 | } |
84 | | - if (metric) { |
85 | | - LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y()); |
| 89 | + catch (std::bad_alloc &) { |
| 90 | + LOG(FATAL) << "out of host memory"; |
| 91 | + exit(EXIT_FAILURE); |
| 92 | + } |
| 93 | + catch (std::exception const &x) { |
| 94 | + LOG(FATAL) << x.what(); |
| 95 | + exit(EXIT_FAILURE); |
| 96 | + } |
| 97 | + catch (...) { |
| 98 | + LOG(FATAL) << "unknown error"; |
| 99 | + exit(EXIT_FAILURE); |
86 | 100 | } |
87 | | - delete model; |
88 | | - delete metric; |
89 | 101 | } |
90 | 102 |
|
0 commit comments