2020 */
2121
2222#include < knp/core/messaging/messaging.h>
23- #include < knp/framework/inference_evaluation/classification.h>
23+ #include < knp/framework/inference_evaluation/classification/processor .h>
2424#include < knp/framework/inference_evaluation/perfomance_metrics.h>
2525
2626#include < algorithm>
3030namespace knp ::framework::inference_evaluation::classification
3131{
3232
33- class InferenceResultForClass ::InferenceResultsProcessor:: EvaluationHelper
33+ class EvaluationHelper
3434{
3535public:
36- explicit EvaluationHelper (const knp::framework::data_processing::classification::Dataset &dataset);
36+ explicit EvaluationHelper (
37+ const knp::framework::data_processing::classification::Dataset &dataset,
38+ std::vector<InferenceResult> &inference_results);
3739
3840 void process_spikes (const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step);
3941
40- [[nodiscard]] std::vector<InferenceResultForClass > process_inference_predictions () const ;
42+ [[nodiscard]] std::vector<InferenceResult > process_inference_predictions () const ;
4143
4244private:
4345 struct Prediction
@@ -53,18 +55,19 @@ class InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper
5355 std::vector<size_t > class_votes_;
5456
5557 const knp::framework::data_processing::classification::Dataset &dataset_;
58+ std::vector<InferenceResult> &inference_results_;
5659};
5760
5861
59- InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::EvaluationHelper (
60- const knp::framework::data_processing::classification::Dataset &dataset)
61- : class_votes_(dataset.get_amount_of_classes(), 0 ), dataset_(dataset)
62+ EvaluationHelper::EvaluationHelper (
63+ const knp::framework::data_processing::classification::Dataset &dataset,
64+ std::vector<InferenceResult> &inference_results)
65+ : class_votes_(dataset.get_amount_of_classes(), 0 ), dataset_(dataset), inference_results_(inference_results)
6266{
6367}
6468
6569
66- void InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_spikes (
67- const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step)
70+ void EvaluationHelper::process_spikes (const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step)
6871{
6972 for (auto i : firing_neuron_indices) ++class_votes_[i % dataset_.get_amount_of_classes ()];
7073 if (!((step + 1 ) % dataset_.get_steps_per_frame ()))
@@ -86,10 +89,9 @@ void InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::proce
8689}
8790
8891
89- std::vector<InferenceResultForClass>
90- InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_inference_predictions () const
92+ std::vector<InferenceResult> EvaluationHelper::process_inference_predictions () const
9193{
92- std::vector<InferenceResultForClass > prediction_results (dataset_.get_amount_of_classes ());
94+ std::vector<InferenceResult > prediction_results (dataset_.get_amount_of_classes ());
9395 for (size_t i = 0 ; i < predictions_.size (); ++i)
9496 {
9597 auto const &prediction = predictions_[i];
@@ -113,11 +115,11 @@ InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_in
113115}
114116
115117
116- void InferenceResultForClass:: InferenceResultsProcessor::process_inference_results (
118+ void InferenceResultsProcessor::process_inference_results (
117119 const std::vector<knp::core::messaging::SpikeMessage> &spikes,
118120 knp::framework::data_processing::classification::Dataset const &dataset)
119121{
120- EvaluationHelper helper (dataset);
122+ EvaluationHelper helper (dataset, inference_results_ );
121123 knp::core::messaging::SpikeData firing_neuron_indices;
122124 auto spikes_iter = spikes.begin ();
123125
@@ -137,11 +139,10 @@ void InferenceResultForClass::InferenceResultsProcessor::process_inference_resul
137139}
138140
139141
140- void InferenceResultForClass::InferenceResultsProcessor::write_inference_results_to_stream_as_csv (
141- std::ostream &results_stream)
142+ void InferenceResultsProcessor::write_inference_results_to_stream_as_csv (std::ostream &results_stream)
142143{
143144 results_stream << " CLASS,TOTAL_VOTES,TRUE_POSITIVES,FALSE_NEGATIVES,FALSE_POSITIVES,TRUE_NEGATIVES,PRECISION,"
144- " RECALL,PREVALENCE,ACCURACY,F_MEASURE \n " ;
145+ " RECALL,PREVALENCE,ACCURACY,F_SCORE \n " ;
145146 for (size_t label = 0 ; label < inference_results_.size (); ++label)
146147 {
147148 auto const &prediction = inference_results_[label];
@@ -153,12 +154,12 @@ void InferenceResultForClass::InferenceResultsProcessor::write_inference_results
153154 const float accuracy = get_accuracy (
154155 prediction.true_positives_ , prediction.false_negatives_ , prediction.false_positives_ ,
155156 prediction.true_negatives_ );
156- const float f_measure = get_f_measure (precision, recall);
157+ const float f_score = get_f_score (precision, recall);
157158
158159 results_stream << label << ' ,' << prediction.get_total_votes () << ' ,' << prediction.true_positives_ << ' ,'
159160 << prediction.false_negatives_ << ' ,' << prediction.false_positives_ << ' ,'
160161 << prediction.true_negatives_ << ' ,' << precision << ' ,' << recall << ' ,' << prevalence << ' ,'
161- << accuracy << ' ,' << f_measure << std::endl;
162+ << accuracy << ' ,' << f_score << std::endl;
162163 }
163164}
164165
0 commit comments