Skip to content

Commit 0d23eff

Browse files
committed
Fix segmentation type case and format
Signed-off-by: Florian Vahl <florian@flova.de>
1 parent 86d74ec commit 0d23eff

16 files changed

Lines changed: 373 additions & 638 deletions

src/bitbots_motion/bitbots_head_mover/src/move_head.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ class HeadMover {
146146
// Bring the goal point into the planning frame
147147
geometry_msgs::msg::PointStamped new_point;
148148
try {
149-
new_point =
150-
tf_buffer_->transform(goal->look_at_position, "base_footprint", tf2::durationFromSec(0.9));
149+
new_point = tf_buffer_->transform(goal->look_at_position, "base_footprint", tf2::durationFromSec(0.9));
151150
} catch (tf2::TransformException& ex) {
152151
RCLCPP_ERROR(node_->get_logger(), "Could not transform goal point: %s", ex.what());
153152
return rclcpp_action::GoalResponse::REJECT;

src/bitbots_vision/include/bitbots_vision/candidate.hpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ struct Candidate {
1313
int height{0}; ///< Box height
1414
float rating{0.f};
1515

16-
static Candidate from_x1y1x2y2(int x1, int y1, int x2, int y2, float rating)
17-
{
16+
static Candidate from_x1y1x2y2(int x1, int y1, int x2, int y2, float rating) {
1817
return {std::min(x1, x2), std::min(y1, y2), std::abs(x2 - x1), std::abs(y2 - y1), rating};
1918
}
2019

@@ -24,18 +23,15 @@ struct Candidate {
2423
int y2() const { return y1 + height; }
2524
int radius() const { return (width + height) / 4; }
2625

27-
static std::vector<Candidate> sort_by_rating(std::vector<Candidate> candidates)
28-
{
26+
static std::vector<Candidate> sort_by_rating(std::vector<Candidate> candidates) {
2927
std::sort(candidates.begin(), candidates.end(),
30-
[](const Candidate & a, const Candidate & b) { return a.rating > b.rating; });
28+
[](const Candidate& a, const Candidate& b) { return a.rating > b.rating; });
3129
return candidates;
3230
}
3331

34-
static std::vector<Candidate> filter_by_rating(
35-
const std::vector<Candidate> & candidates, float threshold)
36-
{
32+
static std::vector<Candidate> filter_by_rating(const std::vector<Candidate>& candidates, float threshold) {
3733
std::vector<Candidate> result;
38-
for (const auto & c : candidates) {
34+
for (const auto& c : candidates) {
3935
if (c.rating > threshold) {
4036
result.push_back(c);
4137
}

src/bitbots_vision/include/bitbots_vision/debug_image.hpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,33 @@
77
namespace bitbots_vision {
88

99
/// Draws detection and segmentation results on a copy of the input image.
10-
class DebugImage
11-
{
12-
public:
10+
class DebugImage {
11+
public:
1312
// BGR color constants (cv::Scalar is not a literal type, so inline static const)
14-
inline static const cv::Scalar kBall{0, 255, 0}; // green
13+
inline static const cv::Scalar kBall{0, 255, 0}; // green
1514
inline static const cv::Scalar kRobotTeamMates{255, 255, 102}; // cyan
1615
inline static const cv::Scalar kRobotOpponents{153, 51, 255}; // magenta
1716
inline static const cv::Scalar kRobotUnknown{160, 160, 160}; // grey
18-
inline static const cv::Scalar kGoalposts{255, 255, 255}; // white
19-
inline static const cv::Scalar kLines{255, 0, 0}; // blue
17+
inline static const cv::Scalar kGoalposts{255, 255, 255}; // white
18+
inline static const cv::Scalar kLines{255, 0, 0}; // blue
2019

2120
explicit DebugImage(bool active = false) : active_(active) {}
2221

23-
void set_image(const cv::Mat & image);
22+
void set_image(const cv::Mat& image);
2423

2524
/// Draw a circle around each candidate (suitable for ball-shaped objects).
26-
void draw_ball_candidates(
27-
const std::vector<Candidate> & candidates,
28-
const cv::Scalar & color,
29-
int thickness = 1);
25+
void draw_ball_candidates(const std::vector<Candidate>& candidates, const cv::Scalar& color, int thickness = 1);
3026

3127
/// Draw a bounding rectangle for each candidate (suitable for robots, goalposts).
32-
void draw_box_candidates(
33-
const std::vector<Candidate> & candidates,
34-
const cv::Scalar & color,
35-
int thickness = 1);
28+
void draw_box_candidates(const std::vector<Candidate>& candidates, const cv::Scalar& color, int thickness = 1);
3629

3730
/// Blend a segmentation mask over the debug image.
3831
/// @param mask CV_8UC1 binary mask (non-zero = active)
39-
void draw_mask(const cv::Mat & mask, const cv::Scalar & color, double opacity = 0.5);
32+
void draw_mask(const cv::Mat& mask, const cv::Scalar& color, double opacity = 0.5);
4033

41-
const cv::Mat & get_image() const { return image_; }
34+
const cv::Mat& get_image() const { return image_; }
4235

43-
private:
36+
private:
4437
bool active_{false};
4538
cv::Mat image_;
4639
};

src/bitbots_vision/include/bitbots_vision/model_config.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,21 @@
66
namespace bitbots_vision {
77

88
/// Configuration loaded from the model's `model_config.yaml`.
9-
class ModelConfig
10-
{
11-
public:
9+
class ModelConfig {
10+
public:
1211
ModelConfig() = default;
1312

1413
/// Load from `<model_dir>/model_config.yaml`. Throws on error.
15-
static ModelConfig load_from(const std::string & model_dir);
14+
static ModelConfig load_from(const std::string& model_dir);
1615

17-
const std::vector<std::string> & detection_classes() const { return detection_classes_; }
18-
const std::vector<std::string> & segmentation_classes() const { return segmentation_classes_; }
16+
const std::vector<std::string>& detection_classes() const { return detection_classes_; }
17+
const std::vector<std::string>& segmentation_classes() const { return segmentation_classes_; }
1918
bool team_colors_provided() const { return team_colors_provided_; }
2019

2120
/// Indices of all classes whose name contains "robot".
2221
std::vector<int> robot_class_ids() const;
2322

24-
private:
23+
private:
2524
std::vector<std::string> detection_classes_;
2625
std::vector<std::string> segmentation_classes_;
2726
bool team_colors_provided_{false};
Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#pragma once
22

3+
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
4+
35
#include <memory>
6+
#include <opencv2/core.hpp>
7+
#include <rclcpp/logger.hpp>
48
#include <string>
59
#include <unordered_map>
610
#include <vector>
711

8-
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
9-
#include <opencv2/core.hpp>
10-
#include <rclcpp/logger.hpp>
11-
1212
#include "bitbots_vision/candidate.hpp"
1313
#include "bitbots_vision/model_config.hpp"
1414
#include "bitbots_vision/yoeo_processing.hpp"
@@ -25,44 +25,40 @@ namespace bitbots_vision {
2525
/// - two outputs:
2626
/// [0] detections shape [1, N, 5+num_classes] (x_c, y_c, w, h, obj_conf, class_probs…)
2727
/// [1] segmentation shape [1, H_seg, W_seg] (argmax class index per pixel)
28-
class YoeoHandler
29-
{
30-
public:
28+
class YoeoHandler {
29+
public:
3130
struct Config {
3231
float conf_threshold{0.5f};
3332
float nms_threshold{0.4f};
3433
};
3534

36-
YoeoHandler(
37-
const std::string & model_path,
38-
const ModelConfig & model_config,
39-
const Config & cfg,
40-
const rclcpp::Logger & logger);
35+
YoeoHandler(const std::string& model_path, const ModelConfig& model_config, const Config& cfg,
36+
const rclcpp::Logger& logger);
4137

4238
/// Update thresholds without reloading the model.
43-
void reconfigure(const Config & cfg);
39+
void reconfigure(const Config& cfg);
4440

4541
/// Set the image to be processed by the next call to predict().
4642
/// The image must be in BGR8 format (as delivered by cv_bridge).
47-
void set_image(const cv::Mat & bgr_image);
43+
void set_image(const cv::Mat& bgr_image);
4844

4945
/// Run the network on the current image (no-op if already up to date).
5046
void predict();
5147

52-
std::vector<Candidate> get_detection_candidates_for(const std::string & class_name);
53-
cv::Mat get_segmentation_mask_for(const std::string & class_name);
48+
std::vector<Candidate> get_detection_candidates_for(const std::string& class_name);
49+
cv::Mat get_segmentation_mask_for(const std::string& class_name);
5450

55-
const std::vector<std::string> & detection_class_names() const;
56-
const std::vector<std::string> & segmentation_class_names() const;
51+
const std::vector<std::string>& detection_class_names() const;
52+
const std::vector<std::string>& segmentation_class_names() const;
5753

58-
private:
54+
private:
5955
// ----- ONNX Runtime objects -----
6056
Ort::Env env_;
6157
Ort::SessionOptions session_options_;
6258
std::unique_ptr<Ort::Session> session_;
6359

6460
// ----- Model metadata -----
65-
std::vector<int64_t> input_shape_; // [1, 3, H_net, W_net]
61+
std::vector<int64_t> input_shape_; // [1, 3, H_net, W_net]
6662
std::string input_name_;
6763
std::vector<std::string> output_names_;
6864
int num_det_classes_{0};
@@ -74,7 +70,7 @@ class YoeoHandler
7470
Config cfg_;
7571
rclcpp::Logger logger_;
7672

77-
cv::Mat current_image_rgb_; // float32, CHW layout, stored as [C, H, W] in flat vector
73+
cv::Mat current_image_rgb_; // float32, CHW layout, stored as [C, H, W] in flat vector
7874
std::vector<float> input_data_;
7975

8076
bool prediction_is_fresh_{true};
@@ -84,14 +80,14 @@ class YoeoHandler
8480

8581
// Cached results
8682
std::unordered_map<std::string, std::vector<Candidate>> det_results_;
87-
std::unordered_map<std::string, cv::Mat> seg_results_; // CV_8UC1 binary masks
83+
std::unordered_map<std::string, cv::Mat> seg_results_; // CV_8UC1 binary masks
8884

8985
// ----- Helpers -----
90-
void init_session(const std::string & model_path);
91-
void preprocess(const cv::Mat & bgr_image);
86+
void init_session(const std::string& model_path);
87+
void preprocess(const cv::Mat& bgr_image);
9288
void run_inference();
93-
void postprocess_detections(const float * det_data, const std::vector<int64_t> & shape);
94-
void postprocess_segmentation(const float * seg_data, const std::vector<int64_t> & shape);
89+
void postprocess_detections(const float* det_data, const std::vector<int64_t>& shape);
90+
void postprocess_segmentation(const uint8_t* seg_data, const std::vector<int64_t>& shape);
9591
};
9692

9793
} // namespace bitbots_vision

src/bitbots_vision/include/bitbots_vision/yoeo_processing.hpp

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#pragma once
22

3+
#include <opencv2/core.hpp>
34
#include <string>
45
#include <unordered_map>
56
#include <vector>
67

7-
#include <opencv2/core.hpp>
8-
98
#include "bitbots_vision/candidate.hpp"
109

1110
/// Pure preprocessing / postprocessing functions for YOEO inference.
@@ -23,7 +22,7 @@ namespace bitbots_vision::processing {
2322
struct PreprocessInfo {
2423
int orig_h{0};
2524
int orig_w{0};
26-
int max_dim{0}; ///< max(orig_h, orig_w) == side length of the padded square
25+
int max_dim{0}; ///< max(orig_h, orig_w) == side length of the padded square
2726
int pad_top{0};
2827
int pad_bottom{0};
2928
int pad_left{0};
@@ -46,8 +45,7 @@ struct PreprocessInfo {
4645
/// @param net_w Network input width
4746
/// @param info [out] Padding / sizing metadata for use in postprocessors
4847
/// @return Flattened CHW float32 tensor of size 3 * net_h * net_w
49-
std::vector<float> preprocess_image(
50-
const cv::Mat & bgr, int net_h, int net_w, PreprocessInfo & info);
48+
std::vector<float> preprocess_image(const cv::Mat& bgr, int net_h, int net_w, PreprocessInfo& info);
5149

5250
// ---------------------------------------------------------------------------
5351
// Non-maximum suppression
@@ -66,13 +64,9 @@ std::vector<float> preprocess_image(
6664
/// @param nms_threshold IoU threshold for suppression
6765
/// @param max_detections Upper limit on kept boxes
6866
/// @return Indices of kept boxes
69-
std::vector<int> nms_boxes(
70-
const std::vector<cv::Rect2d> & boxes,
71-
const std::vector<float> & scores,
72-
const std::vector<int> & class_ids,
73-
const std::vector<int> & robot_class_ids,
74-
float nms_threshold,
75-
int max_detections = 30);
67+
std::vector<int> nms_boxes(const std::vector<cv::Rect2d>& boxes, const std::vector<float>& scores,
68+
const std::vector<int>& class_ids, const std::vector<int>& robot_class_ids,
69+
float nms_threshold, int max_detections = 30);
7670

7771
// ---------------------------------------------------------------------------
7872
// Detection postprocessing
@@ -94,14 +88,8 @@ std::vector<int> nms_boxes(
9488
/// @param info Preprocessing metadata for coordinate rescaling
9589
/// @return Map of class_name → kept Candidates
9690
std::unordered_map<std::string, std::vector<Candidate>> postprocess_detections(
97-
const float * data,
98-
int64_t num_boxes,
99-
int64_t stride,
100-
const std::vector<std::string> & class_names,
101-
const std::vector<int> & robot_class_ids,
102-
float conf_thresh,
103-
float nms_thresh,
104-
const PreprocessInfo & info);
91+
const float* data, int64_t num_boxes, int64_t stride, const std::vector<std::string>& class_names,
92+
const std::vector<int>& robot_class_ids, float conf_thresh, float nms_thresh, const PreprocessInfo& info);
10593

10694
// ---------------------------------------------------------------------------
10795
// Segmentation postprocessing
@@ -122,11 +110,8 @@ std::unordered_map<std::string, std::vector<Candidate>> postprocess_detections(
122110
/// @param class_names Segmentation class names (index == class id)
123111
/// @param info Preprocessing metadata for unpadding
124112
/// @return Map of class_name → CV_8UC1 binary mask
125-
std::unordered_map<std::string, cv::Mat> postprocess_segmentation(
126-
const float * data,
127-
int seg_h,
128-
int seg_w,
129-
const std::vector<std::string> & class_names,
130-
const PreprocessInfo & info);
113+
std::unordered_map<std::string, cv::Mat> postprocess_segmentation(const uint8_t* data, int seg_h, int seg_w,
114+
const std::vector<std::string>& class_names,
115+
const PreprocessInfo& info);
131116

132117
} // namespace bitbots_vision::processing

src/bitbots_vision/src/debug_image.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,28 @@
44

55
namespace bitbots_vision {
66

7-
void DebugImage::set_image(const cv::Mat & image)
8-
{
9-
image_ = image.clone();
10-
}
7+
void DebugImage::set_image(const cv::Mat& image) { image_ = image.clone(); }
118

12-
void DebugImage::draw_ball_candidates(
13-
const std::vector<Candidate> & candidates,
14-
const cv::Scalar & color,
15-
int thickness)
16-
{
9+
void DebugImage::draw_ball_candidates(const std::vector<Candidate>& candidates, const cv::Scalar& color,
10+
int thickness) {
1711
if (!active_) {
1812
return;
1913
}
20-
for (const auto & c : candidates) {
14+
for (const auto& c : candidates) {
2115
cv::circle(image_, {c.center_x(), c.center_y()}, c.radius(), color, thickness);
2216
}
2317
}
2418

25-
void DebugImage::draw_box_candidates(
26-
const std::vector<Candidate> & candidates,
27-
const cv::Scalar & color,
28-
int thickness)
29-
{
19+
void DebugImage::draw_box_candidates(const std::vector<Candidate>& candidates, const cv::Scalar& color, int thickness) {
3020
if (!active_) {
3121
return;
3222
}
33-
for (const auto & c : candidates) {
23+
for (const auto& c : candidates) {
3424
cv::rectangle(image_, {c.x1, c.y1}, {c.x2(), c.y2()}, color, thickness);
3525
}
3626
}
3727

38-
void DebugImage::draw_mask(const cv::Mat & mask, const cv::Scalar & color, double opacity)
39-
{
28+
void DebugImage::draw_mask(const cv::Mat& mask, const cv::Scalar& color, double opacity) {
4029
if (!active_ || mask.empty()) {
4130
return;
4231
}

src/bitbots_vision/src/model_config.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#include "bitbots_vision/model_config.hpp"
22

3+
#include <yaml-cpp/yaml.h>
4+
35
#include <filesystem>
46
#include <stdexcept>
57

6-
#include <yaml-cpp/yaml.h>
7-
88
namespace bitbots_vision {
99

10-
ModelConfig ModelConfig::load_from(const std::string & model_dir)
11-
{
10+
ModelConfig ModelConfig::load_from(const std::string& model_dir) {
1211
const std::string path = model_dir + "/model_config.yaml";
1312
if (!std::filesystem::exists(path)) {
1413
throw std::runtime_error("model_config.yaml not found at: " + path);
@@ -18,11 +17,11 @@ ModelConfig ModelConfig::load_from(const std::string & model_dir)
1817

1918
ModelConfig cfg;
2019

21-
for (const auto & name : root["detection"]["classes"]) {
20+
for (const auto& name : root["detection"]["classes"]) {
2221
cfg.detection_classes_.push_back(name.as<std::string>());
2322
}
2423

25-
for (const auto & name : root["segmentation"]["classes"]) {
24+
for (const auto& name : root["segmentation"]["classes"]) {
2625
cfg.segmentation_classes_.push_back(name.as<std::string>());
2726
}
2827

@@ -33,8 +32,7 @@ ModelConfig ModelConfig::load_from(const std::string & model_dir)
3332
return cfg;
3433
}
3534

36-
std::vector<int> ModelConfig::robot_class_ids() const
37-
{
35+
std::vector<int> ModelConfig::robot_class_ids() const {
3836
std::vector<int> ids;
3937
for (int i = 0; i < static_cast<int>(detection_classes_.size()); ++i) {
4038
if (detection_classes_[i].find("robot") != std::string::npos) {

0 commit comments

Comments
 (0)