Commit 66478b6

Add semantic segmentation tutorial (#136)
* Add semantic segmentation demo packages with FP16 model

  Signed-off-by: pepisg <pedro.gonzalez@eia.edu.co>

* segmentation node improvements

  Signed-off-by: pepisg <pedro.gonzalez@eia.edu.co>

* remove metapackage and comments

  Signed-off-by: pepisg <pedro.gonzalez@eia.edu.co>

---------

Signed-off-by: pepisg <pedro.gonzalez@eia.edu.co>
1 parent 79ed9a5 commit 66478b6

19 files changed

Lines changed: 2475 additions & 0 deletions

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Semantic Segmentation Node

ROS2 node for real-time semantic segmentation inference using ONNX Runtime.

## Overview

This node performs semantic segmentation on camera images and publishes segmentation masks, confidence maps, and colored overlays. It uses ONNX Runtime for efficient inference without requiring PyTorch or super-gradients at runtime.
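
As a rough illustration of how a class-ID mask and a confidence map can be derived from raw model output, the following sketch applies argmax and a numerically stable softmax to a tiny logits array. The output shape `[1, num_classes, H, W]` is taken from the node's comments; the helper name `logits_to_mask_and_confidence` and the sample values are assumptions made for this example.

```python
import numpy as np


def logits_to_mask_and_confidence(logits: np.ndarray):
    """Turn logits of shape [1, C, H, W] into a class-ID mask and a 0-255 confidence map."""
    # Class ID per pixel: index of the largest logit along the class axis.
    mask = np.argmax(logits, axis=1).squeeze(0).astype(np.uint8)

    # Numerically stable softmax: subtract the per-pixel maximum before exponentiating.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / np.sum(exp, axis=1, keepdims=True)

    # Confidence is the winning class probability, scaled to the mono8 range.
    confidence = (np.max(probs, axis=1).squeeze(0) * 255.0).astype(np.uint8)
    return mask, confidence


# Hypothetical 3-class logits for a 1x2 image (shape [1, 3, 1, 2]).
logits = np.array([[[[2.0, 0.0]], [[0.0, 0.0]], [[0.0, 5.0]]]], dtype=np.float32)
mask, confidence = logits_to_mask_and_confidence(logits)
print(mask.tolist())  # [[0, 2]]
```

The second pixel has a much larger winning logit than the first, so its confidence value ends up higher.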

## Topics

**Subscribed:**
- `/rgbd_camera/image` (sensor_msgs/Image) - Input RGB camera images

**Published:**
- `/segmentation/mask` (sensor_msgs/Image) - Segmentation mask with class IDs (mono8)
- `/segmentation/confidence` (sensor_msgs/Image) - Per-pixel confidence (mono8, 0-255)
- `/segmentation/overlay` (sensor_msgs/Image) - Colored overlay visualization (bgr8)
- `/segmentation/label_info` (vision_msgs/LabelInfo) - Class mappings (latched)
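
A consumer of these topics will typically combine the mask and confidence images. The sketch below works on plain NumPy arrays rather than ROS messages and resets pixels below a confidence threshold to the background class. The function name `filter_mask` and the sample arrays are illustrative assumptions, not part of the package.

```python
import numpy as np


def filter_mask(mask: np.ndarray, confidence: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Set pixels whose confidence is below `threshold` (0.0-1.0) to background (class 0)."""
    # The confidence image is mono8, so rescale 0-255 back to 0.0-1.0 first.
    conf = confidence.astype(np.float32) / 255.0
    filtered = mask.copy()
    filtered[conf < threshold] = 0
    return filtered


mask = np.array([[1, 2], [2, 1]], dtype=np.uint8)
confidence = np.array([[250, 100], [200, 30]], dtype=np.uint8)
print(filter_mask(mask, confidence).tolist())  # [[1, 0], [2, 0]]
```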

## Model

The ONNX model (`models/model.onnx`) can be generated using the [Simple Segmentation Toolkit](https://github.com/pepisg/simple_segmentation_toolkit).

### Training Your Own Model

1. Capture training images from a real robot or from Gazebo, with varying lighting and environmental conditions
2. Use the Simple Segmentation Toolkit to label and train a model
3. Convert the trained model to ONNX format: `python3 convert_to_onnx.py`
4. Copy `model.onnx` to this package's `models/` directory

The ontology configuration (`config/ontology.yaml`) must match the classes used during training.
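
One way to check that an ontology lines up with a trained model is to build the class-ID map the way the node does (background at ID 0, then the configured classes in order) and compare its size with the number of model output channels. In this sketch the `config` dictionary stands in for the parsed YAML, and `build_class_map` and `model_output_channels` are hypothetical names.

```python
def build_class_map(config: dict) -> dict:
    """Map class IDs to names: 0 is background, configured classes follow in order."""
    class_map = {0: 'background'}
    for class_id, entry in enumerate(config['ontology']['classes'], start=1):
        class_map[class_id] = entry['name']
    return class_map


# Stand-in for the result of parsing config/ontology.yaml.
config = {
    'ontology': {
        'classes': [
            {'name': 'sidewalk', 'color': [255, 0, 0]},
            {'name': 'grass', 'color': [0, 255, 0]},
        ]
    }
}

class_map = build_class_map(config)
print(class_map)  # {0: 'background', 1: 'sidewalk', 2: 'grass'}

# Assumed channel count of the model output ([1, num_classes, H, W]).
model_output_channels = 3
if len(class_map) != model_output_channels:
    raise ValueError('ontology.yaml does not match the classes the model was trained on')
```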

## Usage

```bash
ros2 run semantic_segmentation_node segmentation_node
```

All dependencies are included in the devcontainer.
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Ontology configuration for semantic segmentation
# Defines the classes to detect and the colors used to visualize them

ontology:
  # List of classes to detect
  # Classes should be defined in the same order as the model output. ID 0 is always the background class.
  # The list is used to build the label_info message, which maps class IDs to class names,
  # and the colors are used for the mask's visualization.
  classes:
    - name: sidewalk
      color: [255, 0, 0]  # BGR format: Blue
    - name: grass
      color: [0, 255, 0]  # BGR format: Green

# Model settings
model:
  device: cpu  # cuda or cpu
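
The `device` setting maps onto an ordered list of ONNX Runtime execution providers. A minimal sketch of that mapping follows. The helper `select_providers` is hypothetical, and unlike the node it also filters the list against the providers that are actually available; the `available` argument stands in for the list returned by `onnxruntime.get_available_providers()` so the example runs without ONNX Runtime installed.

```python
def select_providers(device: str, available: list) -> list:
    """Return execution providers in priority order for the requested device."""
    if device.lower() == 'cuda':
        # Prefer CUDA, but keep the CPU provider as a fallback.
        wanted = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        wanted = ['CPUExecutionProvider']
    # Keep only providers that this ONNX Runtime build actually offers.
    return [p for p in wanted if p in available]


cpu_only_build = ['CPUExecutionProvider']
print(select_providers('cuda', cpu_only_build))  # ['CPUExecutionProvider']
print(select_providers('cpu', cpu_only_build))   # ['CPUExecutionProvider']
```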
Binary file not shown.
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>semantic_segmentation_node</name>
  <version>0.0.1</version>
  <description>ROS2 node for semantic segmentation inference</description>
  <maintainer email="pedro.gonzalez@eia.edu.co">Pedro Gonzalez</maintainer>
  <license>BSD-3-Clause</license>

  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>std_msgs</depend>
  <depend>vision_msgs</depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
PyYAML>=5.4.0
opencv-python>=4.5.0
numpy>=1.19.0

# ONNX Runtime for model inference
# For CPU-only: onnxruntime
# For GPU support: onnxruntime-gpu
onnxruntime>=1.10.0

nav2_semantic_segmentation_demo/semantic_segmentation_node/resource/semantic_segmentation_node

Whitespace-only changes.

nav2_semantic_segmentation_demo/semantic_segmentation_node/semantic_segmentation_node/__init__.py

Whitespace-only changes.
Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""ROS2 node for semantic segmentation inference using ONNX Runtime."""

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, DurabilityPolicy
from sensor_msgs.msg import Image
from vision_msgs.msg import LabelInfo, VisionClass
from cv_bridge import CvBridge
import cv2
import numpy as np
import onnxruntime as ort
import yaml
from pathlib import Path
from ament_index_python.packages import get_package_share_directory


class SegmentationNode(Node):
    """
    ROS2 node that performs semantic segmentation using ONNX Runtime.

    NOTE:
    This node runs on CPU by default for compatibility with all hardware, but can run on GPU if you install the required
    ONNX Runtime GPU dependencies. See instructions at:
    https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements
    """

    def __init__(self):
        super().__init__('segmentation_node')

        # Get package share directory using ament_index
        package_share = Path(get_package_share_directory('semantic_segmentation_node'))
        model_path = package_share / 'models' / 'model.onnx'
        config_path = package_share / 'config' / 'ontology.yaml'

        # Load config
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        self.class_names = [cls['name'] for cls in config['ontology']['classes']]
        self.class_colors = [cls['color'] for cls in config['ontology']['classes']]  # BGR format
        self.num_classes = len(self.class_names) + 1  # +1 for background

        # Get device setting from config
        device = config.get('model', {}).get('device', 'cpu').lower()

        self.get_logger().info(f'Loading ONNX model from: {model_path}')
        self.get_logger().info(f'Number of classes: {self.num_classes}')
        self.get_logger().info(f'Device setting: {device}')

        # Set providers based on device setting
        if device == 'cuda':
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']

        self.session = ort.InferenceSession(str(model_path), providers=providers)

        # Get model device
        provider = self.session.get_providers()[0]
        self.get_logger().info(f'Using provider: {provider}')

        # Detect model input type (FP32 or FP16)
        input_meta = self.session.get_inputs()[0]
        self.input_dtype = input_meta.type
        self.use_fp16 = 'float16' in str(self.input_dtype).lower()
        self.get_logger().info(f'Model input type: {self.input_dtype} (FP16: {self.use_fp16})')

        # Image normalization (ImageNet normalization)
        # Use float16 for mean/std if model expects FP16, otherwise float32
        dtype = np.float16 if self.use_fp16 else np.float32
        self.mean = np.array([0.485, 0.456, 0.406], dtype=dtype).reshape(1, 3, 1, 1)
        self.std = np.array([0.229, 0.224, 0.225], dtype=dtype).reshape(1, 3, 1, 1)

        # CV bridge
        self.bridge = CvBridge()

        # Declare parameters
        self.declare_parameter('input_topic', '/rgbd_camera/image')
        self.declare_parameter('mask_topic', '/segmentation/mask')
        self.declare_parameter('confidence_topic', '/segmentation/confidence')
        self.declare_parameter('label_info_topic', '/segmentation/label_info')
        self.declare_parameter('overlay_topic', '/segmentation/overlay')
        self.declare_parameter('publish_overlay', True)

        # Get parameters
        input_topic = self.get_parameter('input_topic').get_parameter_value().string_value
        mask_topic = self.get_parameter('mask_topic').get_parameter_value().string_value
        confidence_topic = self.get_parameter('confidence_topic').get_parameter_value().string_value
        label_info_topic = self.get_parameter('label_info_topic').get_parameter_value().string_value
        overlay_topic = self.get_parameter('overlay_topic').get_parameter_value().string_value
        publish_overlay = self.get_parameter('publish_overlay').get_parameter_value().bool_value

        # Create subscribers and publishers
        self.subscription = self.create_subscription(
            Image,
            input_topic,
            self.image_callback,
            10
        )

        self.mask_publisher = self.create_publisher(
            Image,
            mask_topic,
            10
        )

        self.confidence_publisher = self.create_publisher(
            Image,
            confidence_topic,
            10
        )

        # Create overlay publisher if enabled
        self.overlay_publisher = None
        if publish_overlay:
            self.overlay_publisher = self.create_publisher(
                Image,
                overlay_topic,
                10
            )

        # Create LabelInfo publisher with transient local QoS
        label_info_qos = QoSProfile(
            depth=1,
            durability=DurabilityPolicy.TRANSIENT_LOCAL
        )
        self.label_info_publisher = self.create_publisher(
            LabelInfo,
            label_info_topic,
            label_info_qos
        )

        # Create and publish LabelInfo message
        self.publish_label_info()

        self.get_logger().info(f'Subscribing to: {input_topic}')
        self.get_logger().info(f'Publishing mask to: {mask_topic}')
        self.get_logger().info(f'Publishing confidence to: {confidence_topic}')
        self.get_logger().info(f'Publishing label info to: {label_info_topic}')
        if publish_overlay:
            self.get_logger().info(f'Publishing overlay to: {overlay_topic}')

    def publish_label_info(self):
        """Publish LabelInfo message with class mappings."""
        label_info = LabelInfo()
        label_info.header.stamp = self.get_clock().now().to_msg()
        label_info.header.frame_id = ''  # Not tied to a specific frame

        # Build class map: background is class 0, then classes from config
        class_map = []

        # Background class (class 0)
        bg_class = VisionClass()
        bg_class.class_id = 0
        bg_class.class_name = 'background'
        class_map.append(bg_class)

        # Add classes from ontology
        for idx, class_name in enumerate(self.class_names, start=1):
            vc = VisionClass()
            vc.class_id = idx
            vc.class_name = class_name
            class_map.append(vc)

        label_info.class_map = class_map
        label_info.threshold = 0.5  # Default confidence threshold

        self.label_info_publisher.publish(label_info)
        self.get_logger().info(f'Published LabelInfo with {len(class_map)} classes')

    def create_colored_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Convert class ID mask to colored visualization.

        Args:
            mask: Single-channel mask with class IDs [H, W]

        Returns:
            Colored mask in BGR format [H, W, 3]
        """
        h, w = mask.shape
        colored = np.zeros((h, w, 3), dtype=np.uint8)

        # Background stays black
        for class_id in range(1, self.num_classes):
            if class_id <= len(self.class_colors):
                color = self.class_colors[class_id - 1]
                colored[mask == class_id] = color

        return colored

    def image_callback(self, msg):
        """Process incoming image and publish segmentation results."""
        # Convert ROS image to OpenCV format (BGR)
        cv_image = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')

        # Convert BGR to RGB
        rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

        # Preprocess image
        # Convert to the model's expected dtype (FP16 or FP32)
        dtype = np.float16 if self.use_fp16 else np.float32
        input_tensor = rgb_image.transpose(2, 0, 1).astype(dtype) / 255.0
        # Apply ImageNet normalization
        input_tensor = (input_tensor - self.mean.squeeze(0)) / self.std.squeeze(0)
        # Add batch dimension
        input_tensor = np.expand_dims(input_tensor, axis=0)

        # Run ONNX inference
        outputs = self.session.run(None, {'input': input_tensor})
        output = outputs[0]  # Shape: [1, num_classes, H, W]

        # Get prediction (class IDs)
        prediction = np.argmax(output, axis=1).squeeze(0).astype(np.uint8)

        # Get confidence (max probability per pixel)
        # Apply softmax manually
        exp_output = np.exp(output - np.max(output, axis=1, keepdims=True))
        probabilities = exp_output / np.sum(exp_output, axis=1, keepdims=True)
        confidence = np.max(probabilities, axis=1).squeeze(0)
        confidence_uint8 = (confidence * 255.0).astype(np.uint8)

        # Create mask image message
        mask_msg = self.bridge.cv2_to_imgmsg(prediction, encoding='mono8')
        mask_msg.header = msg.header

        # Create confidence image message
        confidence_msg = self.bridge.cv2_to_imgmsg(confidence_uint8, encoding='mono8')
        confidence_msg.header = msg.header

        # Publish mask and confidence
        self.mask_publisher.publish(mask_msg)
        self.confidence_publisher.publish(confidence_msg)

        # Create and publish overlay if enabled
        if self.overlay_publisher is not None:
            pred_colored = self.create_colored_mask(prediction)
            overlay = cv2.addWeighted(cv_image, 0.7, pred_colored, 0.3, 0)
            overlay_msg = self.bridge.cv2_to_imgmsg(overlay, encoding='bgr8')
            overlay_msg.header = msg.header
            self.overlay_publisher.publish(overlay_msg)


def main(args=None):
    rclpy.init(args=args)
    node = SegmentationNode()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
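
The preprocessing in `image_callback` can be exercised without ROS or a model. The sketch below reproduces the same steps on a synthetic image: BGR to RGB, channels-first layout, scaling to 0-1, ImageNet mean/std normalization, and a batch dimension. The helper names `input_dtype_for` and `preprocess` are hypothetical, the channel flip uses slicing instead of `cv2.cvtColor`, and the type strings `tensor(float16)` and `tensor(float)` are assumed to be what ONNX Runtime reports for FP16 and FP32 inputs.

```python
import numpy as np

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def input_dtype_for(onnx_type: str):
    """Pick the NumPy dtype from an ONNX Runtime type string such as 'tensor(float16)'."""
    return np.float16 if 'float16' in onnx_type.lower() else np.float32


def preprocess(bgr_image: np.ndarray, dtype) -> np.ndarray:
    """Convert an [H, W, 3] BGR uint8 image into a normalized [1, 3, H, W] tensor."""
    rgb = bgr_image[:, :, ::-1]  # BGR -> RGB by reversing the channel axis
    chw = rgb.transpose(2, 0, 1).astype(dtype) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=dtype).reshape(3, 1, 1)
    std = np.array(IMAGENET_STD, dtype=dtype).reshape(3, 1, 1)
    # Normalize per channel, then add the batch dimension.
    return np.expand_dims((chw - mean) / std, axis=0)


image = np.zeros((4, 6, 3), dtype=np.uint8)  # synthetic black image, 4 rows x 6 columns
tensor = preprocess(image, input_dtype_for('tensor(float16)'))
print(tensor.shape, tensor.dtype)  # (1, 3, 4, 6) float16
```

Keeping the mean and standard deviation in the same dtype as the image avoids an implicit upcast to FP32, which an FP16 model input would reject.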
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/semantic_segmentation_node
[install]
install_scripts=$base/lib/semantic_segmentation_node
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
from setuptools import find_packages, setup
import os
from glob import glob

package_name = 'semantic_segmentation_node'

setup(
    name='semantic_segmentation_node',
    version='0.0.0',
    packages=find_packages(exclude=['test']),
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        (os.path.join('share', package_name, 'models'), glob('models/*.onnx')),
        (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='ros',
    maintainer_email='pedro.gonzalez@eia.edu.co',
    description='ROS2 node for semantic segmentation inference',
    license='BSD-3-Clause',
    extras_require={
        'test': [
            'pytest',
        ],
    },
    entry_points={
        'console_scripts': [
            'segmentation_node = semantic_segmentation_node.segmentation_node:main',
        ],
    },
)
