Skip to content

Commit b1c4cab

Browse files
authored
Merge pull request #104 from DavidIkov/fix_for_inference
Fix for inference_evaluation and f_score
2 parents 1a5f79b + 2ef6e32 commit b1c4cab

9 files changed

Lines changed: 168 additions & 151 deletions

File tree

examples/mnist-learn/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
* limitations under the License.
2020
*/
2121

22-
#include <knp/framework/inference_evaluation/classification.h>
22+
#include <knp/framework/inference_evaluation/classification/processor.h>
2323

2424
#include <filesystem>
2525
#include <fstream>
@@ -80,7 +80,7 @@ int main(int argc, char** argv)
8080
std::cout << get_time_string() << ": inference finished -- output spike count is " << spikes.size() << std::endl;
8181

8282
// Evaluate results.
83-
inference_evaluation::InferenceResultForClass::InferenceResultsProcessor inference_processor;
83+
inference_evaluation::InferenceResultsProcessor inference_processor;
8484
inference_processor.process_inference_results(spikes, dataset);
8585

8686
inference_processor.write_inference_results_to_stream_as_csv(std::cout);

knp/base-framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ knp_add_library("${PROJECT_NAME}-core"
7373
impl/sonata/types/additive_delta_synapse.cpp
7474
impl/data_processing/classification/dataset.cpp
7575
impl/data_processing/classification/image.cpp
76-
impl/inference_evaluation/classification.cpp
7776
impl/inference_evaluation/perfomance_metrics.cpp
77+
impl/inference_evaluation/classification/processor.cpp
7878
impl/observer.cpp
7979
${${PROJECT_NAME}_headers}
8080
ALIAS KNP::BaseFramework::Core

knp/base-framework/impl/inference_evaluation/classification.cpp renamed to knp/base-framework/impl/inference_evaluation/classification/processor.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*/
2121

2222
#include <knp/core/messaging/messaging.h>
23-
#include <knp/framework/inference_evaluation/classification.h>
23+
#include <knp/framework/inference_evaluation/classification/processor.h>
2424
#include <knp/framework/inference_evaluation/perfomance_metrics.h>
2525

2626
#include <algorithm>
@@ -30,14 +30,16 @@
3030
namespace knp::framework::inference_evaluation::classification
3131
{
3232

33-
class InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper
33+
class EvaluationHelper
3434
{
3535
public:
36-
explicit EvaluationHelper(const knp::framework::data_processing::classification::Dataset &dataset);
36+
explicit EvaluationHelper(
37+
const knp::framework::data_processing::classification::Dataset &dataset,
38+
std::vector<InferenceResult> &inference_results);
3739

3840
void process_spikes(const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step);
3941

40-
[[nodiscard]] std::vector<InferenceResultForClass> process_inference_predictions() const;
42+
[[nodiscard]] std::vector<InferenceResult> process_inference_predictions() const;
4143

4244
private:
4345
struct Prediction
@@ -53,18 +55,19 @@ class InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper
5355
std::vector<size_t> class_votes_;
5456

5557
const knp::framework::data_processing::classification::Dataset &dataset_;
58+
std::vector<InferenceResult> &inference_results_;
5659
};
5760

5861

59-
InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::EvaluationHelper(
60-
const knp::framework::data_processing::classification::Dataset &dataset)
61-
: class_votes_(dataset.get_amount_of_classes(), 0), dataset_(dataset)
62+
EvaluationHelper::EvaluationHelper(
63+
const knp::framework::data_processing::classification::Dataset &dataset,
64+
std::vector<InferenceResult> &inference_results)
65+
: class_votes_(dataset.get_amount_of_classes(), 0), dataset_(dataset), inference_results_(inference_results)
6266
{
6367
}
6468

6569

66-
void InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_spikes(
67-
const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step)
70+
void EvaluationHelper::process_spikes(const knp::core::messaging::SpikeData &firing_neuron_indices, size_t step)
6871
{
6972
for (auto i : firing_neuron_indices) ++class_votes_[i % dataset_.get_amount_of_classes()];
7073
if (!((step + 1) % dataset_.get_steps_per_frame()))
@@ -86,10 +89,9 @@ void InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::proce
8689
}
8790

8891

89-
std::vector<InferenceResultForClass>
90-
InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_inference_predictions() const
92+
std::vector<InferenceResult> EvaluationHelper::process_inference_predictions() const
9193
{
92-
std::vector<InferenceResultForClass> prediction_results(dataset_.get_amount_of_classes());
94+
std::vector<InferenceResult> prediction_results(dataset_.get_amount_of_classes());
9395
for (size_t i = 0; i < predictions_.size(); ++i)
9496
{
9597
auto const &prediction = predictions_[i];
@@ -113,11 +115,11 @@ InferenceResultForClass::InferenceResultsProcessor::EvaluationHelper::process_in
113115
}
114116

115117

116-
void InferenceResultForClass::InferenceResultsProcessor::process_inference_results(
118+
void InferenceResultsProcessor::process_inference_results(
117119
const std::vector<knp::core::messaging::SpikeMessage> &spikes,
118120
knp::framework::data_processing::classification::Dataset const &dataset)
119121
{
120-
EvaluationHelper helper(dataset);
122+
EvaluationHelper helper(dataset, inference_results_);
121123
knp::core::messaging::SpikeData firing_neuron_indices;
122124
auto spikes_iter = spikes.begin();
123125

@@ -137,11 +139,10 @@ void InferenceResultForClass::InferenceResultsProcessor::process_inference_resul
137139
}
138140

139141

140-
void InferenceResultForClass::InferenceResultsProcessor::write_inference_results_to_stream_as_csv(
141-
std::ostream &results_stream)
142+
void InferenceResultsProcessor::write_inference_results_to_stream_as_csv(std::ostream &results_stream)
142143
{
143144
results_stream << "CLASS,TOTAL_VOTES,TRUE_POSITIVES,FALSE_NEGATIVES,FALSE_POSITIVES,TRUE_NEGATIVES,PRECISION,"
144-
"RECALL,PREVALENCE,ACCURACY,F_MEASURE\n";
145+
"RECALL,PREVALENCE,ACCURACY,F_SCORE\n";
145146
for (size_t label = 0; label < inference_results_.size(); ++label)
146147
{
147148
auto const &prediction = inference_results_[label];
@@ -153,12 +154,12 @@ void InferenceResultForClass::InferenceResultsProcessor::write_inference_results
153154
const float accuracy = get_accuracy(
154155
prediction.true_positives_, prediction.false_negatives_, prediction.false_positives_,
155156
prediction.true_negatives_);
156-
const float f_measure = get_f_measure(precision, recall);
157+
const float f_score = get_f_score(precision, recall);
157158

158159
results_stream << label << ',' << prediction.get_total_votes() << ',' << prediction.true_positives_ << ','
159160
<< prediction.false_negatives_ << ',' << prediction.false_positives_ << ','
160161
<< prediction.true_negatives_ << ',' << precision << ',' << recall << ',' << prevalence << ','
161-
<< accuracy << ',' << f_measure << std::endl;
162+
<< accuracy << ',' << f_score << std::endl;
162163
}
163164
}
164165

knp/base-framework/impl/inference_evaluation/perfomance_metrics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ float get_accuracy(size_t true_positives, size_t false_negatives, size_t false_p
5555
}
5656

5757

58-
float get_f_measure(float precision, float recall)
58+
float get_f_score(float precision, float recall)
5959
{
6060
if (precision * recall == 0) return 0.F;
6161
return 2.F * precision * recall / (precision + recall);

knp/base-framework/include/knp/framework/inference_evaluation/classification.h

Lines changed: 0 additions & 115 deletions
This file was deleted.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/**
2+
* @file processor.h
3+
* @brief Processing inference results.
4+
* @kaspersky_support D. Postnikov
5+
* @date 05.09.2025
6+
* @license Apache 2.0
7+
* @copyright © 2025 AO Kaspersky Lab
8+
*
9+
* Licensed under the Apache License, Version 2.0 (the "License");
10+
* you may not use this file except in compliance with the License.
11+
* You may obtain a copy of the License at
12+
*
13+
* http://www.apache.org/licenses/LICENSE-2.0
14+
*
15+
* Unless required by applicable law or agreed to in writing, software
16+
* distributed under the License is distributed on an "AS IS" BASIS,
17+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
* See the License for the specific language governing permissions and
19+
* limitations under the License.
20+
*/
21+
22+
#pragma once
23+
24+
#include <knp/core/messaging/messaging.h>
25+
#include <knp/framework/data_processing/classification/image.h>
26+
27+
#include <vector>
28+
29+
#include "result.h"
30+
31+
32+
namespace knp::framework::inference_evaluation::classification
33+
{
34+
35+
/**
36+
* @details A class to process inference results.
37+
*/
38+
class KNP_DECLSPEC InferenceResultsProcessor
39+
{
40+
public:
41+
/**
42+
* @brief Process inference results. Suited for classification models.
43+
* @param spikes All spikes from inference.
44+
* @param dataset Dataset.
45+
*/
46+
void process_inference_results(
47+
const std::vector<knp::core::messaging::SpikeMessage> &spikes,
48+
const knp::framework::data_processing::classification::Dataset &dataset);
49+
50+
/**
51+
* @brief Put inference results for each class to a stream in form of csv.
52+
* @param results_stream stream for output.
53+
*/
54+
void write_inference_results_to_stream_as_csv(std::ostream &results_stream);
55+
56+
/**
57+
* @brief Get inference results.
58+
* @return Inference results.
59+
*/
60+
[[nodiscard]] const std::vector<InferenceResult> &get_inference_results() const { return inference_results_; }
61+
62+
private:
63+
/**
64+
* @brief Processed inference results.
65+
*/
66+
std::vector<InferenceResult> inference_results_;
67+
};
68+
69+
} // namespace knp::framework::inference_evaluation::classification
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/**
2+
* @file result.h
3+
* @brief Structure to hold inference results.
4+
* @kaspersky_support D. Postnikov
5+
* @date 16.07.2025
6+
* @license Apache 2.0
7+
* @copyright © 2025 AO Kaspersky Lab
8+
*
9+
* Licensed under the Apache License, Version 2.0 (the "License");
10+
* you may not use this file except in compliance with the License.
11+
* You may obtain a copy of the License at
12+
*
13+
* http://www.apache.org/licenses/LICENSE-2.0
14+
*
15+
* Unless required by applicable law or agreed to in writing, software
16+
* distributed under the License is distributed on an "AS IS" BASIS,
17+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
* See the License for the specific language governing permissions and
19+
* limitations under the License.
20+
*/
21+
22+
#pragma once
23+
24+
#include <knp/core/impexp.h>
25+
26+
namespace knp::framework::inference_evaluation::classification
27+
{
28+
29+
/**
30+
* @brief Processed inference result for single class.
31+
*/
32+
struct KNP_DECLSPEC InferenceResult
33+
{
34+
/**
35+
* @brief Amount of times model, that is supposed to predict dog, predicted dog when it is a dog.
36+
*/
37+
size_t true_positives_ = 0;
38+
39+
/**
40+
* @brief Amount of times model, that is supposed to predict dog, predicted not a dog when it is a dog.
41+
*/
42+
size_t false_negatives_ = 0;
43+
44+
/**
45+
* @brief Amount of times model, that is supposed to predict dog, predicted dog when it is not a dog.
46+
*/
47+
size_t false_positives_ = 0;
48+
49+
/**
50+
* @brief Amount of times model, that is supposed to predict dog, predicted not a dog when it is a not a dog.
51+
*/
52+
size_t true_negatives_ = 0;
53+
54+
/**
55+
* @brief Shortcut for getting total votes.
56+
* @return Total votes.
57+
*/
58+
[[nodiscard]] size_t get_total_votes() const { return true_positives_ + false_negatives_ + false_positives_; }
59+
};
60+
61+
} // namespace knp::framework::inference_evaluation::classification

0 commit comments

Comments
 (0)