11#pragma once
22
3+ #include < onnxruntime/core/session/onnxruntime_cxx_api.h>
4+
35#include < memory>
6+ #include < opencv2/core.hpp>
7+ #include < rclcpp/logger.hpp>
48#include < string>
59#include < unordered_map>
610#include < vector>
711
8- #include < onnxruntime/core/session/onnxruntime_cxx_api.h>
9- #include < opencv2/core.hpp>
10- #include < rclcpp/logger.hpp>
11-
1212#include " bitbots_vision/candidate.hpp"
1313#include " bitbots_vision/model_config.hpp"
1414#include " bitbots_vision/yoeo_processing.hpp"
@@ -25,44 +25,40 @@ namespace bitbots_vision {
2525// / - two outputs:
2626// / [0] detections shape [1, N, 5+num_classes] (x_c, y_c, w, h, obj_conf, class_probs…)
2727// / [1] segmentation shape [1, H_seg, W_seg] (argmax class index per pixel)
28- class YoeoHandler
29- {
30- public:
28+ class YoeoHandler {
29+ public:
3130 struct Config {
3231 float conf_threshold{0 .5f };
3332 float nms_threshold{0 .4f };
3433 };
3534
36- YoeoHandler (
37- const std::string & model_path,
38- const ModelConfig & model_config,
39- const Config & cfg,
40- const rclcpp::Logger & logger);
35+ YoeoHandler (const std::string& model_path, const ModelConfig& model_config, const Config& cfg,
36+ const rclcpp::Logger& logger);
4137
4238 // / Update thresholds without reloading the model.
43- void reconfigure (const Config & cfg);
39+ void reconfigure (const Config& cfg);
4440
4541 // / Set the image to be processed by the next call to predict().
4642 // / The image must be in BGR8 format (as delivered by cv_bridge).
47- void set_image (const cv::Mat & bgr_image);
43+ void set_image (const cv::Mat& bgr_image);
4844
4945 // / Run the network on the current image (no-op if already up to date).
5046 void predict ();
5147
52- std::vector<Candidate> get_detection_candidates_for (const std::string & class_name);
53- cv::Mat get_segmentation_mask_for (const std::string & class_name);
48+ std::vector<Candidate> get_detection_candidates_for (const std::string& class_name);
49+ cv::Mat get_segmentation_mask_for (const std::string& class_name);
5450
55- const std::vector<std::string> & detection_class_names () const ;
56- const std::vector<std::string> & segmentation_class_names () const ;
51+ const std::vector<std::string>& detection_class_names () const ;
52+ const std::vector<std::string>& segmentation_class_names () const ;
5753
58- private:
54+ private:
5955 // ----- ONNX Runtime objects -----
6056 Ort::Env env_;
6157 Ort::SessionOptions session_options_;
6258 std::unique_ptr<Ort::Session> session_;
6359
6460 // ----- Model metadata -----
65- std::vector<int64_t > input_shape_; // [1, 3, H_net, W_net]
61+ std::vector<int64_t > input_shape_; // [1, 3, H_net, W_net]
6662 std::string input_name_;
6763 std::vector<std::string> output_names_;
6864 int num_det_classes_{0 };
@@ -74,7 +70,7 @@ class YoeoHandler
7470 Config cfg_;
7571 rclcpp::Logger logger_;
7672
77- cv::Mat current_image_rgb_; // float32, CHW layout, stored as [C, H, W] in flat vector
73+ cv::Mat current_image_rgb_; // float32, CHW layout, stored as [C, H, W] in flat vector
7874 std::vector<float > input_data_;
7975
8076 bool prediction_is_fresh_{true };
@@ -84,14 +80,14 @@ class YoeoHandler
8480
8581 // Cached results
8682 std::unordered_map<std::string, std::vector<Candidate>> det_results_;
87- std::unordered_map<std::string, cv::Mat> seg_results_; // CV_8UC1 binary masks
83+ std::unordered_map<std::string, cv::Mat> seg_results_; // CV_8UC1 binary masks
8884
8985 // ----- Helpers -----
90- void init_session (const std::string & model_path);
91- void preprocess (const cv::Mat & bgr_image);
86+ void init_session (const std::string& model_path);
87+ void preprocess (const cv::Mat& bgr_image);
9288 void run_inference ();
93- void postprocess_detections (const float * det_data, const std::vector<int64_t > & shape);
94- void postprocess_segmentation (const float * seg_data, const std::vector<int64_t > & shape);
89+ void postprocess_detections (const float * det_data, const std::vector<int64_t >& shape);
90+ void postprocess_segmentation (const uint8_t * seg_data, const std::vector<int64_t >& shape);
9591};
9692
9793} // namespace bitbots_vision
0 commit comments