Skip to content

Commit e7e134f

Browse files
committed
Add semantic segmentation demo packages with FP16 model
Signed-off-by: pepisg <pedro.gonzalez@eia.edu.co>
1 parent 79ed9a5 commit e7e134f

21 files changed

Lines changed: 2483 additions & 0 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
cmake_minimum_required(VERSION 3.8)
2+
project(nav2_semantic_segmentation_demo)
3+
4+
find_package(ament_cmake REQUIRED)
5+
6+
ament_package()
7+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?xml version="1.0"?>
2+
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
3+
<package format="3">
4+
<name>nav2_semantic_segmentation_demo</name>
5+
<version>0.0.0</version>
6+
<description>Metapackage for semantic segmentation demo packages</description>
7+
<maintainer email="user@example.com">User</maintainer>
8+
<license>BSD-3-Clause</license>
9+
10+
<buildtool_depend>ament_cmake</buildtool_depend>
11+
12+
<exec_depend>semantic_segmentation_node</exec_depend>
13+
<exec_depend>semantic_segmentation_sim</exec_depend>
14+
15+
<export>
16+
<build_type>ament_cmake</build_type>
17+
</export>
18+
</package>
19+
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Semantic Segmentation Node
2+
3+
ROS2 node for real-time semantic segmentation inference using ONNX Runtime.
4+
5+
## Overview
6+
7+
This node performs semantic segmentation on camera images and publishes segmentation masks, confidence maps, and colored overlays. It uses ONNX Runtime for efficient inference without requiring PyTorch or super-gradients at runtime.
8+
9+
## Topics
10+
11+
**Subscribed:**
12+
- `/rgbd_camera/image` (sensor_msgs/Image) - Input RGB camera images
13+
14+
**Published:**
15+
- `/segmentation/mask` (sensor_msgs/Image) - Segmentation mask with class IDs (mono8)
16+
- `/segmentation/confidence` (sensor_msgs/Image) - Per-pixel confidence (mono8, 0-255)
17+
- `/segmentation/overlay` (sensor_msgs/Image) - Colored overlay visualization (bgr8)
18+
- `/segmentation/label_info` (vision_msgs/LabelInfo) - Class mappings (latched)
19+
20+
## Model
21+
22+
The ONNX model (`models/model.onnx`) can be generated using the [Simple Segmentation Toolkit](https://github.com/pepisg/simple_segmentation_toolkit).
23+
24+
### Training Your Own Model
25+
26+
1. Capture training images from Gazebo with varying lighting and environmental conditions
27+
2. Use the Simple Segmentation Toolkit to label and train a model
28+
3. Convert the trained model to ONNX format: `python3 convert_to_onnx.py`
29+
4. Copy `model.onnx` to this package's `models/` directory
30+
31+
The ontology configuration (`config/ontology.yaml`) must match the classes used during training.
32+
33+
## Usage
34+
35+
```bash
36+
ros2 run semantic_segmentation_node segmentation_node
37+
```
38+
39+
All dependencies are included in the devcontainer.
40+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Ontology configuration for semantic segmentation
2+
# Defines the classes to detect and the color used to visualize each one
3+
4+
ontology:
5+
# List of classes to detect
6+
# Each entry gives a class name and the BGR color used to draw it in the overlay
7+
classes:
8+
- name: sidewalk
9+
color: [255, 0, 0] # BGR format: Blue
10+
- name: grass
11+
color: [0, 255, 0] # BGR format: Green
12+
13+
# Model settings
14+
model:
15+
device: cpu # cuda or cpu (NOTE: segmentation_node does not read this key; it selects the ONNX Runtime provider itself)
Binary file not shown.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?xml version="1.0"?>
2+
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
3+
<package format="3">
4+
<name>semantic_segmentation_node</name>
5+
<version>0.0.0</version>
6+
<description>ROS2 node for semantic segmentation inference</description>
7+
<maintainer email="pedro.gonzalez@eia.edu.co">ros</maintainer>
8+
<license>BSD-3-Clause</license>
9+
10+
<depend>rclpy</depend>
11+
<depend>sensor_msgs</depend>
12+
<depend>cv_bridge</depend>
13+
<depend>std_msgs</depend>
14+
<depend>vision_msgs</depend>
15+
16+
<export>
17+
<build_type>ament_python</build_type>
18+
</export>
19+
</package>
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
PyYAML>=5.4.0
2+
opencv-python>=4.5.0
3+
numpy>=1.19.0
4+
5+
# ONNX Runtime for model inference
6+
# For CPU-only: onnxruntime
7+
# For GPU support: onnxruntime-gpu
8+
onnxruntime>=1.10.0
9+
10+

nav2_semantic_segmentation_demo/semantic_segmentation_node/resource/semantic_segmentation_node

Whitespace-only changes.

nav2_semantic_segmentation_demo/semantic_segmentation_node/semantic_segmentation_node/__init__.py

Whitespace-only changes.
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
#!/usr/bin/env python3
2+
"""ROS2 node for semantic segmentation inference using ONNX Runtime."""
3+
4+
import rclpy
5+
from rclpy.node import Node
6+
from sensor_msgs.msg import Image
7+
from vision_msgs.msg import LabelInfo, VisionClass
8+
from cv_bridge import CvBridge
9+
import cv2
10+
import numpy as np
11+
import onnxruntime as ort
12+
import yaml
13+
from pathlib import Path
14+
15+
16+
class SegmentationNode(Node):
17+
"""ROS2 node that performs semantic segmentation using ONNX Runtime."""
18+
19+
def __init__(self):
    """Load the ontology and ONNX model, then create all ROS interfaces.

    Reads ``config/ontology.yaml`` and ``models/model.onnx`` from the
    package share directory (falling back to the source tree for the
    model), creates the ONNX Runtime session, declares the topic
    parameters, and creates the image subscription plus the mask,
    confidence, optional overlay and LabelInfo publishers.

    Raises:
        FileNotFoundError: If ``model.onnx`` is found neither in the
            package share directory nor in the source-tree fallbacks.
    """
    super().__init__('segmentation_node')

    # Function-scope imports are kept local (as in the original) so the
    # module-level import block does not need to change.
    from ament_index_python.packages import get_package_share_directory
    from rclpy.qos import DurabilityPolicy, QoSProfile

    package_share = Path(get_package_share_directory('semantic_segmentation_node'))
    model_path = package_share / 'models' / 'model.onnx'
    config_path = package_share / 'config' / 'ontology.yaml'

    # If the model is not installed (e.g. symlink install), look in the
    # source tree. The first candidate is the original lookup; the second
    # resolves symlinks and assumes this module sits in <package>/<package>/
    # with models/ at the package root -- NOTE(review): confirm the layout.
    if not model_path.exists():
        here = Path(__file__)
        candidates = (
            here.parent.parent.parent / 'models' / 'model.onnx',
            here.resolve().parent.parent / 'models' / 'model.onnx',
        )
        found = next((c for c in candidates if c.exists()), None)
        if found is None:
            searched = ', '.join(str(p) for p in (model_path, *candidates))
            raise FileNotFoundError(f'ONNX model not found; searched: {searched}')
        model_path = found

    # Load config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    classes = config['ontology']['classes']
    self.class_names = [cls['name'] for cls in classes]
    self.class_colors = [cls['color'] for cls in classes]  # BGR format
    self.num_classes = len(self.class_names) + 1  # +1 for background

    self.get_logger().info(f'Loading ONNX model from: {model_path}')
    self.get_logger().info(f'Number of classes: {self.num_classes}')

    # Prefer CUDA, but only request providers this onnxruntime build
    # actually offers so a CPU-only install does not ask for CUDA.
    preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    available = ort.get_available_providers()
    providers = [p for p in preferred if p in available] or ['CPUExecutionProvider']
    self.session = ort.InferenceSession(str(model_path), providers=providers)

    # The first entry of get_providers() is the one logged as in use
    provider = self.session.get_providers()[0]
    self.get_logger().info(f'Using provider: {provider}')

    # Detect model input type (FP32 or FP16)
    input_meta = self.session.get_inputs()[0]
    self.input_dtype = input_meta.type
    self.use_fp16 = 'float16' in str(self.input_dtype).lower()
    self.get_logger().info(f'Model input type: {self.input_dtype} (FP16: {self.use_fp16})')

    # Image normalization (ImageNet mean/std), stored in the model's input
    # dtype and shaped [1, 3, 1, 1]
    dtype = np.float16 if self.use_fp16 else np.float32
    self.mean = np.array([0.485, 0.456, 0.406], dtype=dtype).reshape(1, 3, 1, 1)
    self.std = np.array([0.229, 0.224, 0.225], dtype=dtype).reshape(1, 3, 1, 1)

    # CV bridge
    self.bridge = CvBridge()

    # Declare parameters
    self.declare_parameter('input_topic', '/rgbd_camera/image')
    self.declare_parameter('mask_topic', '/segmentation/mask')
    self.declare_parameter('confidence_topic', '/segmentation/confidence')
    self.declare_parameter('label_info_topic', '/segmentation/label_info')
    self.declare_parameter('overlay_topic', '/segmentation/overlay')
    self.declare_parameter('publish_overlay', True)

    # Get parameters
    input_topic = self.get_parameter('input_topic').get_parameter_value().string_value
    mask_topic = self.get_parameter('mask_topic').get_parameter_value().string_value
    confidence_topic = self.get_parameter('confidence_topic').get_parameter_value().string_value
    label_info_topic = self.get_parameter('label_info_topic').get_parameter_value().string_value
    overlay_topic = self.get_parameter('overlay_topic').get_parameter_value().string_value
    publish_overlay = self.get_parameter('publish_overlay').get_parameter_value().bool_value

    # Create subscribers and publishers
    self.subscription = self.create_subscription(
        Image,
        input_topic,
        self.image_callback,
        10,
    )
    self.mask_publisher = self.create_publisher(Image, mask_topic, 10)
    self.confidence_publisher = self.create_publisher(Image, confidence_topic, 10)

    # The overlay publisher stays None when disabled; image_callback
    # checks for None before building the overlay.
    self.overlay_publisher = None
    if publish_overlay:
        self.overlay_publisher = self.create_publisher(Image, overlay_topic, 10)

    # LabelInfo uses transient-local durability with depth 1
    label_info_qos = QoSProfile(
        depth=1,
        durability=DurabilityPolicy.TRANSIENT_LOCAL,
    )
    self.label_info_publisher = self.create_publisher(
        LabelInfo,
        label_info_topic,
        label_info_qos,
    )

    # Create and publish LabelInfo message
    self.publish_label_info()

    self.get_logger().info(f'Subscribing to: {input_topic}')
    self.get_logger().info(f'Publishing mask to: {mask_topic}')
    self.get_logger().info(f'Publishing confidence to: {confidence_topic}')
    self.get_logger().info(f'Publishing label info to: {label_info_topic}')
    if publish_overlay:
        self.get_logger().info(f'Publishing overlay to: {overlay_topic}')
134+
135+
def publish_label_info(self):
    """Publish the class-ID-to-name table as a LabelInfo message.

    Class 0 is always 'background'; the ontology classes follow in
    config order starting at ID 1. The message threshold is set to 0.5.
    """
    # Background first, then the configured classes, so that the position
    # in this list is the class ID.
    ordered_names = ['background', *self.class_names]

    entries = []
    for class_id, name in enumerate(ordered_names):
        entry = VisionClass()
        entry.class_id = class_id
        entry.class_name = name
        entries.append(entry)

    msg = LabelInfo()
    msg.header.stamp = self.get_clock().now().to_msg()
    msg.header.frame_id = ''  # Not tied to a specific frame
    msg.class_map = entries
    msg.threshold = 0.5  # Default confidence threshold

    self.label_info_publisher.publish(msg)
    self.get_logger().info(f'Published LabelInfo with {len(entries)} classes')
162+
163+
def create_colored_mask(self, mask: np.ndarray) -> np.ndarray:
    """
    Paint each class ID of a mask with its configured BGR color.

    Args:
        mask: Single-channel mask with class IDs [H, W]

    Returns:
        uint8 image [H, W, 3] in BGR order. Class 0 and any ID without a
        configured color are left black.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    # Class IDs start at 1 (0 is background). zip() stops at whichever is
    # shorter: the ID range or the configured color list.
    for label, bgr in zip(range(1, self.num_classes), self.class_colors):
        canvas[mask == label] = bgr

    return canvas
183+
184+
def image_callback(self, msg):
    """Run segmentation on one camera frame and publish the results.

    Publishes the class-ID mask and the per-pixel confidence (both
    mono8) with the incoming header, and a blended BGR overlay when the
    overlay publisher is enabled.

    Args:
        msg: Incoming image message; converted to BGR8 through cv_bridge.
    """
    # Convert ROS image to OpenCV format (BGR), then to RGB for the model
    cv_image = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
    rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

    # Preprocess: HWC uint8 -> CHW float in [0, 1], ImageNet-normalized,
    # in the dtype the model input declares (FP16 or FP32).
    dtype = np.float16 if self.use_fp16 else np.float32
    input_tensor = rgb_image.transpose(2, 0, 1).astype(dtype) / 255.0
    input_tensor = (input_tensor - self.mean.squeeze(0)) / self.std.squeeze(0)
    # Add batch dimension and make sure arithmetic promotion did not
    # change the dtype the session expects.
    input_tensor = np.expand_dims(input_tensor, axis=0).astype(dtype, copy=False)

    # Feed the tensor under the model's own input name rather than assuming
    # the exporter called it 'input'.
    input_name = self.session.get_inputs()[0].name
    outputs = self.session.run(None, {input_name: input_tensor})
    # Assumes outputs[0] holds per-class scores shaped [1, num_classes, H, W].
    # Cast to float32 so softmax on an FP16 model is not computed in half
    # precision.
    logits = outputs[0].astype(np.float32, copy=False)

    # Get prediction (class IDs)
    prediction = np.argmax(logits, axis=1).squeeze(0).astype(np.uint8)

    # Confidence = max softmax probability per pixel; subtracting the
    # per-pixel max first keeps exp() from overflowing.
    exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
    confidence = np.max(probabilities, axis=1).squeeze(0)
    confidence_uint8 = (confidence * 255.0).astype(np.uint8)

    # Create mask and confidence messages, reusing the input header so
    # consumers can match them to the source frame.
    mask_msg = self.bridge.cv2_to_imgmsg(prediction, encoding='mono8')
    mask_msg.header = msg.header
    confidence_msg = self.bridge.cv2_to_imgmsg(confidence_uint8, encoding='mono8')
    confidence_msg.header = msg.header

    # Publish mask and confidence
    self.mask_publisher.publish(mask_msg)
    self.confidence_publisher.publish(confidence_msg)

    # Create and publish overlay if enabled
    if self.overlay_publisher is not None:
        pred_colored = self.create_colored_mask(prediction)
        # addWeighted needs equal sizes; if the model output resolution
        # differs from the camera frame, scale the color mask with
        # nearest-neighbor so class colors are not blended.
        if pred_colored.shape[:2] != cv_image.shape[:2]:
            pred_colored = cv2.resize(
                pred_colored,
                (cv_image.shape[1], cv_image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        overlay = cv2.addWeighted(cv_image, 0.7, pred_colored, 0.3, 0)
        overlay_msg = self.bridge.cv2_to_imgmsg(overlay, encoding='bgr8')
        overlay_msg.header = msg.header
        self.overlay_publisher.publish(overlay_msg)
234+
235+
236+
def main(args=None):
    """Initialize rclpy, spin a SegmentationNode, and always clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    # Local import keeps the module-level import block unchanged.
    from rclpy.executors import ExternalShutdownException

    rclpy.init(args=args)
    node = None
    try:
        node = SegmentationNode()
        rclpy.spin(node)
    except (KeyboardInterrupt, ExternalShutdownException):
        # Ctrl-C or an external shutdown request is a normal way to stop.
        pass
    finally:
        # Runs even if the constructor raised, so rclpy is not left
        # initialized and a created node is always destroyed.
        if node is not None:
            node.destroy_node()
        # The context may already have been shut down by signal handling;
        # only shut down if it is still valid.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
246+

0 commit comments

Comments
 (0)