cut3r_package_improved.py
#!/home/race10/cut3r_venv/bin/python
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image, PointCloud2
from cv_bridge import CvBridge
import cv2
import numpy as np
import torch
import torch.nn as nn
import sys
import os
import math
from sensor_msgs_py import point_cloud2
import std_msgs.msg
from torch.nn.functional import interpolate
from scipy.spatial.transform import Rotation
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy, LivelinessPolicy

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Current device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

class CUT3RProcessor(Node):
    def __init__(self):
        super().__init__('cut3r_processor')

        # Declare parameters
        self.declare_parameter('cut3r_path', '/home/race10/CUT3R')
        self.declare_parameter('model_path', '/home/race10/CUT3R/src/cut3r_512_dpt_4_64.pth')
        self.declare_parameter('publish_current_pointcloud', True)
        self.declare_parameter('publish_aggregated_pointcloud', True)
        self.declare_parameter('publish_depth_map', True)
        self.declare_parameter('voxel_size', 0.05)  # Downsampling resolution
        self.declare_parameter('enable_colors', True)  # Enable color extraction
        self.declare_parameter('max_accumulated_frames', 72)  # Maximum frames to accumulate before clearing

        cut3r_path = self.get_parameter('cut3r_path').value
        model_path = self.get_parameter('model_path').value

        # Add CUT3R paths to sys.path
        sys.path.append(cut3r_path)
        sys.path.append(os.path.join(cut3r_path, 'src'))

        # Import CUT3R modules
        try:
            # Import from the paths used by the CUT3R demo files
            from src.dust3r.model import ARCroco3DStereo
            from src.dust3r.inference import inference
            from src.dust3r.utils.camera import pose_encoding_to_camera  # For extracting camera poses
            self.inference = inference  # Store as instance variable
            self.pose_encoding_to_camera = pose_encoding_to_camera  # Store pose extraction function
            self.get_logger().info("Successfully imported CUT3R modules")
        except ImportError as e:
            self.get_logger().error(f"Failed to import CUT3R modules: {e}")
            return

        # ROS communication setup
        sensor_qos_profile = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
            liveliness=LivelinessPolicy.AUTOMATIC,
            depth=1  # Queue size
        )
        self.subscription = self.create_subscription(
            Image,
            # '/cam1/color/image_raw',
            # '/home/race10/data',
            '/camera/image',
            self.image_callback,
            sensor_qos_profile
        )

        # Initialize pose tracking for CUT3R poses
        self.current_cut3r_pose = None
        self.previous_cut3r_pose = None
        self.reference_pose = None  # Reference pose for a consistent world frame

        # Publishers
        self.pointcloud_publisher = self.create_publisher(PointCloud2, '/cut3r/aggregated_pointcloud', 10)
        self.current_pointcloud_publisher = self.create_publisher(PointCloud2, '/cut3r/current_pointcloud', 10)
        self.depth_map_publisher = self.create_publisher(Image, '/cut3r/depth_map', 10)

        # Initialize attributes
        self.cv_bridge = CvBridge()
        self.frame_count = 0

        # Load CUT3R model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.get_logger().info(f"The device is: {self.device}")
        try:
            self.model = ARCroco3DStereo.from_pretrained(model_path).to(self.device)
            self.model.eval()
            self.get_logger().info(f"CUT3R model loaded successfully from {model_path}")
        except Exception as e:
            self.get_logger().error(f"Failed to load CUT3R model: {e}")
            return

        # CUT3R persistent state
        self.current_image = None
        self.persistent_state = None
        self.accumulated_points = []
        self.accumulated_colors = []
        self.accumulated_poses = []  # Store poses for each accumulated frame
        self.previous_frame = None

        # Scene change detection parameters
        self.declare_parameter('scene_change_threshold', 1.0)  # Threshold for scene change detection
        self.declare_parameter('angle_change_threshold', 30.0)  # Degrees of camera angle change to trigger reconstruction
        self.declare_parameter('static_detection_frames', 10)  # Number of frames to consider for static detection

        # Scene change detection state
        self.previous_pose = None
        self.static_frame_count = 0
        self.last_reconstruction_pose = None
        self.scene_change_detected = False

        self.get_logger().info("Initialized CUT3R Processor with scene change detection and CUT3R camera pose extraction")
        self.get_logger().set_level(rclpy.logging.LoggingSeverity.DEBUG)

    def image_callback(self, msg):
        """Process a single frame continuously - CUT3R approach"""
        cv_image = self.cv_bridge.imgmsg_to_cv2(msg, desired_encoding='rgb8')
        # CUT3R processes single frames continuously
        self.process_continuous_frame(cv_image)

    def extract_cut3r_camera_poses(self, outputs):
        """Extract camera poses from CUT3R model outputs"""
        try:
            # Extract camera poses from model outputs (similar to the prepare_output function)
            pr_poses = [
                self.pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
                for pred in outputs["pred"]
            ]
            # Convert to numpy arrays and remove the batch dimension
            poses_np = []
            for i, pose in enumerate(pr_poses):
                pose_np = pose.numpy()
                self.get_logger().debug(f"Pose {i} shape before squeeze: {pose_np.shape}")
                # Remove the batch dimension if present (pose should be 4x4, not 1x4x4)
                if pose_np.ndim == 3 and pose_np.shape[0] == 1:
                    pose_np = pose_np.squeeze(0)
                elif pose_np.ndim == 3:
                    # If batch size > 1, take the first one
                    pose_np = pose_np[0]
                self.get_logger().debug(f"Pose {i} shape after squeeze: {pose_np.shape}")
                poses_np.append(pose_np)
            self.get_logger().debug(f"Extracted {len(poses_np)} camera poses from CUT3R model")
            return poses_np
        except Exception as e:
            self.get_logger().error(f"Failed to extract camera poses from CUT3R: {e}")
            # Return None if extraction fails - handled in the calling code
            return None
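
    # Each pose returned above is assumed to be a 4x4 camera-to-world matrix:
    #   [[R t],    R: 3x3 rotation    -> pose[:3, :3]
    #    [0 1]]    t: 3x1 translation -> pose[:3, 3]
    # detect_scene_change() and transform_points_to_world_frame() below rely
    # on exactly this layout.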

    def detect_scene_change(self, current_pose, current_image):
        """Detect if the scene has changed based on real camera pose movement from CUT3R"""
        scene_change_detected = False
        # Fetch the angle threshold up front: it is used by both Method 1 and
        # Method 2 (previously it was only defined inside the Method 1 branch)
        angle_threshold = self.get_parameter('angle_change_threshold').value

        # Method 1: Camera pose change detection using real CUT3R poses
        if self.previous_cut3r_pose is not None:
            # Calculate rotation difference between poses
            rotation_diff = self.calculate_rotation_difference(self.previous_cut3r_pose, current_pose)
            # Calculate translation difference
            translation_diff = np.linalg.norm(current_pose[:3, 3] - self.previous_cut3r_pose[:3, 3])
            translation_threshold = 0.1  # 10 cm threshold
            if rotation_diff > angle_threshold or translation_diff > translation_threshold:
                self.get_logger().info(f"Scene change detected: Camera moved - rotation: {rotation_diff:.2f}°, translation: {translation_diff:.3f}m")
                scene_change_detected = True
                self.static_frame_count = 0
            else:
                self.static_frame_count += 1
        else:
            # First frame - always process
            scene_change_detected = True
            self.static_frame_count = 0

        # Method 2: Static detection - the camera has not moved significantly for many frames
        static_threshold = self.get_parameter('static_detection_frames').value
        if self.static_frame_count >= static_threshold:
            if self.last_reconstruction_pose is None or self.calculate_rotation_difference(self.last_reconstruction_pose, current_pose) > angle_threshold:
                self.get_logger().info(f"Scene change detected: Camera moved after {self.static_frame_count} static frames")
                scene_change_detected = True
                self.static_frame_count = 0

        # Method 3: Image content change detection (optional)
        if hasattr(self, 'previous_image') and self.previous_image is not None:
            image_similarity = self.calculate_image_similarity(self.previous_image, current_image)
            scene_threshold = self.get_parameter('scene_change_threshold').value
            if image_similarity < (1.0 - scene_threshold):
                self.get_logger().info(f"Scene change detected: Image similarity {image_similarity:.3f} below threshold")
                scene_change_detected = True

        # Update state
        self.previous_image = current_image.copy()
        if scene_change_detected:
            self.last_reconstruction_pose = current_pose.copy()
        return scene_change_detected

    def calculate_rotation_difference(self, pose1, pose2):
        """Calculate the rotation difference between two poses in degrees"""
        # Extract rotation matrices
        R1 = pose1[:3, :3]
        R2 = pose2[:3, :3]
        # Calculate the relative rotation
        R_rel = R1.T @ R2
        # Convert to axis-angle representation (Rotation is imported at module level)
        r = Rotation.from_matrix(R_rel)
        axis_angle = r.as_rotvec()
        # Calculate the angle in degrees
        angle_degrees = np.linalg.norm(axis_angle) * 180.0 / np.pi
        return angle_degrees
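
    # Note: an equivalent closed form for the angle (assuming R_rel is a
    # proper rotation matrix) is
    #   angle_deg = np.degrees(np.arccos(np.clip((np.trace(R_rel) - 1) / 2, -1.0, 1.0)))
    # The rotation-vector norm used above yields the same geodesic angle.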

    def calculate_image_similarity(self, img1, img2):
        """Calculate similarity between two images using structural similarity"""
        try:
            from skimage.metrics import structural_similarity as ssim
            # Convert to grayscale for the SSIM calculation
            gray1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
            gray2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
            # Resize to the same size if needed
            if gray1.shape != gray2.shape:
                gray2 = cv2.resize(gray2, (gray1.shape[1], gray1.shape[0]))
            # Calculate SSIM
            similarity = ssim(gray1, gray2)
            return similarity
        except ImportError:
            # Fall back to a simple correlation if scikit-image is not available
            if img1.shape != img2.shape:
                img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
            # Calculate normalized cross-correlation
            img1_norm = img1.astype(np.float32) / 255.0
            img2_norm = img2.astype(np.float32) / 255.0
            correlation = cv2.matchTemplate(img1_norm, img2_norm, cv2.TM_CCOEFF_NORMED)[0][0]
            return max(0, correlation)  # Ensure non-negative

    def establish_reference_frame(self, camera_pose):
        """Establish a reference coordinate system for a consistent world frame"""
        if self.reference_pose is None:
            self.reference_pose = camera_pose.copy()
            self.get_logger().info("Established reference coordinate frame")

    def transform_points_to_world_frame(self, points, camera_pose):
        """Transform points from the camera frame to the world frame using the camera pose"""
        if len(points) == 0:
            return points
        # Debug shapes
        self.get_logger().debug(f"Points shape: {points.shape}")
        self.get_logger().debug(f"Camera pose shape: {camera_pose.shape}")
        # Ensure the camera pose is 4x4
        if camera_pose.shape != (4, 4):
            self.get_logger().error(f"Camera pose has wrong shape: {camera_pose.shape}, expected (4, 4)")
            return points
        # Establish the reference frame if not set
        self.establish_reference_frame(camera_pose)
        # Convert points to homogeneous coordinates
        points_homogeneous = np.column_stack([points, np.ones(len(points))])
        self.get_logger().debug(f"Points homogeneous shape: {points_homogeneous.shape}")
        # The camera_pose from CUT3R is a camera-to-world transformation,
        # so it maps points from the camera frame directly into the world frame
        transformed_points_homogeneous = (camera_pose @ points_homogeneous.T).T
        # Return 3D coordinates (drop the homogeneous coordinate)
        return transformed_points_homogeneous[:, :3]
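
    # Sanity example (illustration only): for a camera translated 1 m along +z
    # with identity rotation, the camera origin maps to (0, 0, 1) in the world:
    #   T = np.eye(4); T[2, 3] = 1.0
    #   (T @ np.array([0.0, 0.0, 0.0, 1.0]))[:3]  # -> array([0., 0., 1.])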

    def voxel_grid_downsample(self, points, colors, voxel_size=0.05):
        """Downsample a point cloud using voxel grid filtering"""
        if len(points) == 0:
            return points, colors
        # Compute voxel indices
        min_bound = points.min(axis=0)
        voxel_indices = np.floor((points - min_bound) / voxel_size).astype(np.int32)
        # Use a dictionary to accumulate points/colors per voxel
        voxel_dict = {}
        for idx, v in enumerate(voxel_indices):
            key = tuple(v)
            if key not in voxel_dict:
                voxel_dict[key] = {'points': [], 'colors': []}
            voxel_dict[key]['points'].append(points[idx])
            voxel_dict[key]['colors'].append(colors[idx])
        # Average the points and colors in each voxel
        downsampled_points = []
        downsampled_colors = []
        for voxel_data in voxel_dict.values():
            downsampled_points.append(np.mean(voxel_data['points'], axis=0))
            downsampled_colors.append(np.mean(voxel_data['colors'], axis=0))
        return np.array(downsampled_points), np.array(downsampled_colors)
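
    # The dictionary loop above does O(N) Python-level work per frame. Below is
    # a minimal vectorized sketch of the same voxel averaging using np.unique;
    # it is not wired into the pipeline, and the name is illustrative only.
    def voxel_grid_downsample_vectorized(self, points, colors, voxel_size=0.05):
        """Vectorized voxel-average downsampling (reference sketch, unused)"""
        if len(points) == 0:
            return points, colors
        min_bound = points.min(axis=0)
        voxel_indices = np.floor((points - min_bound) / voxel_size).astype(np.int64)
        # Group points by voxel: inverse[i] is the voxel id of point i
        _, inverse, counts = np.unique(voxel_indices, axis=0,
                                       return_inverse=True, return_counts=True)
        inverse = inverse.ravel()
        down_points = np.zeros((len(counts), 3), dtype=np.float64)
        down_colors = np.zeros((len(counts), 3), dtype=np.float64)
        # Sum the points/colors falling into each voxel, then divide by counts
        np.add.at(down_points, inverse, points)
        np.add.at(down_colors, inverse, colors.astype(np.float64))
        down_points /= counts[:, None]
        down_colors /= counts[:, None]
        return down_points, down_colors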

    def extract_colors_from_image(self, points, image, image_shape=(224, 224)):
        """Extract RGB colors from the image for the corresponding 3D points"""
        if image is None:
            image = self.current_image
        if image is None:
            # Return default colors if no image is available
            return np.full((len(points), 3), [128, 128, 128], dtype=np.uint8)
        # Resize the image to match the point cloud resolution
        resized_image = cv2.resize(image, image_shape)
        # Create the colors array
        colors = []
        height, width = image_shape
        for i in range(len(points)):
            # Map the point index to image coordinates
            row = i // width
            col = i % width
            # Ensure the coordinates are within bounds
            if row < height and col < width:
                # The callback converts with desired_encoding='rgb8', so the
                # channels are already in RGB order - no BGR swap is needed
                colors.append(resized_image[row, col])
            else:
                colors.append([128, 128, 128])  # Default gray color
        return np.array(colors, dtype=np.uint8)

    def process_continuous_frame(self, current_image):
        """CUT3R processing with real-time pose estimation from the CUT3R model"""
        with torch.no_grad():
            # Prepare the current frame
            current_frame = self.prepare_frame(current_image)
            if self.previous_frame is None:
                # Initialize with the first frame
                self.previous_frame = current_frame
                self.get_logger().info("Initialized CUT3R with first frame")
                return
            # CUT3R processing with persistent state
            try:
                # Create a view pair for inference (current approach in the CUT3R demo)
                views = [self.previous_frame, current_frame]
                # Use the stored inference function
                outputs, state_args = self.inference(views, self.model, self.device)
                # Extract camera poses from the CUT3R model outputs
                cut3r_poses = self.extract_cut3r_camera_poses(outputs)
                # Use poses from the CUT3R model - this is the real camera pose estimation
                if cut3r_poses is not None and len(cut3r_poses) >= 2:
                    # Use the current view pose from CUT3R
                    current_pose = cut3r_poses[1]  # Current view pose
                    self.get_logger().debug("Using CUT3R estimated camera pose")
                elif cut3r_poses is not None and len(cut3r_poses) >= 1:
                    # Fall back to a single pose if available
                    current_pose = cut3r_poses[0]
                    self.get_logger().debug("Using single CUT3R estimated camera pose")
                else:
                    # If no poses are available, skip this frame
                    self.get_logger().warn("No camera poses available from CUT3R, skipping frame")
                    self.previous_frame = current_frame
                    self.frame_count += 1
                    return
                # Store the previous pose for comparison
                self.previous_cut3r_pose = self.current_cut3r_pose
                self.current_cut3r_pose = current_pose
                # Detect whether the scene has changed based on real camera movement
                scene_changed = self.detect_scene_change(current_pose, current_image)
                # Process the frame
                self.get_logger().info(f"Processing frame {self.frame_count} - Scene change detected: {scene_changed}")
                # Update the persistent state - CUT3R's key feature
                self.update_persistent_state(state_args)
                # Extract and process results
                if 'pred' in outputs and len(outputs['pred']) >= 2:
                    current_pred = outputs['pred'][1]  # Current view prediction
                    # Publish the current point cloud
                    self.publish_current_point_cloud(current_pred)
                    # Publish the depth map
                    self.publish_depth_map(current_pred)
                    # Accumulate for dense reconstruction using real camera poses
                    self.publish_accumulated_point_cloud(current_pred)
            except Exception as e:
                self.get_logger().error(f"CUT3R processing failed: {e}")
            # Update for the next iteration - sliding window approach
            self.previous_frame = current_frame
            self.frame_count += 1

    def prepare_frame(self, image):
        """Prepare a single frame for CUT3R processing and store the original image"""
        # Store the original image for color extraction
        self.current_image = image.copy()
        # Resize and normalize the frame to the model's input format
        H, W = 224, 224
        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0).to(self.device) / 255.0
        img_tensor = interpolate(img_tensor, size=(H, W), mode='bilinear')
        true_shape = torch.tensor([H, W]).unsqueeze(0).to(self.device)
        # Use an identity pose initially - CUT3R will estimate the actual pose
        camera_pose = torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(0).to(self.device)
        img_mask = torch.tensor([True], dtype=torch.bool).to(self.device)
        ray_mask = torch.tensor([False], dtype=torch.bool).to(self.device)
        update = torch.tensor([True], dtype=torch.bool).to(self.device)
        reset = torch.tensor([False], dtype=torch.bool).to(self.device)
        instance = str(self.frame_count)
        ray_map = torch.full((1, 6, H, W), torch.nan, dtype=torch.float32).to(self.device)
        return {
            "img": img_tensor,
            "true_shape": true_shape,
            "camera_pose": camera_pose,
            "img_mask": img_mask,
            "ray_map": ray_map,
            "ray_mask": ray_mask,
            "update": update,
            "reset": reset,
            "instance": instance,
        }

    def update_persistent_state(self, state_args):
        """Update CUT3R's persistent state representation"""
        # CUT3R maintains persistent state across observations
        if self.persistent_state is None:
            self.persistent_state = state_args
        else:
            # Continuously update the persistent state
            # This is where CUT3R's continuous updating happens
            self.persistent_state = self.merge_states(self.persistent_state, state_args)

    def merge_states(self, previous_state, new_state):
        """Merge the previous persistent state with new observations"""
        # This is a simplified merge - the actual implementation depends on CUT3R's state structure
        # CUT3R's persistent state accumulates information over time
        if isinstance(new_state, dict) and isinstance(previous_state, dict):
            merged_state = previous_state.copy()
            merged_state.update(new_state)
            return merged_state
        else:
            return new_state

    def reset_for_new_sequence(self):
        """Reset the persistent state for a new sequence - CUT3R capability"""
        self.persistent_state = None
        self.previous_frame = None
        self.accumulated_points = []
        self.accumulated_colors = []
        self.accumulated_poses = []
        self.frame_count = 0
        self.get_logger().info("CUT3R persistent state reset for new sequence")

    def rotate_points(self, points):
        """Rotate points into the ROS coordinate system"""
        r = Rotation.from_euler('x', -90, degrees=True)
        rotated_points = r.apply(points)
        return rotated_points

    def project_points_to_depth_map_current(self, points, image_shape):
        """Project 3D points to a depth map - adapted from the working version"""
        height, width = image_shape
        depth_map = np.full((height, width), np.inf, dtype=np.float32)

        def norm_pt(point):
            return math.sqrt(point[0] ** 2 + point[1] ** 2 + point[2] ** 2)

        for i in range(height):
            for j in range(width):
                if i * width + j < len(points):
                    depth_map[i, j] = norm_pt(points[i * width + j])
        # Replace inf with the maximum finite depth (guard against no points)
        finite_depths = depth_map[depth_map != np.inf]
        if finite_depths.size == 0:
            return np.zeros((height, width), dtype=np.float32)
        depth_map[depth_map == np.inf] = np.max(finite_depths)
        # Normalize the depth map (guard against a constant-depth map)
        depth_range = np.max(depth_map) - np.min(depth_map)
        if depth_range > 0:
            depth_map = (depth_map - np.min(depth_map)) / depth_range
        return depth_map
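
    # A minimal vectorized sketch of the same projection (illustrative only,
    # not called by the node): per-point Euclidean depth in one norm call,
    # reshaped onto the image grid.
    def project_points_to_depth_map_vectorized(self, points, image_shape):
        """Vectorized depth-map projection (reference sketch, unused)"""
        height, width = image_shape
        n = height * width
        depth_map = np.full(n, np.inf, dtype=np.float32)
        # Euclidean distance of each point from the camera origin
        depths = np.linalg.norm(points[:n], axis=1)
        depth_map[:len(depths)] = depths
        depth_map = depth_map.reshape(height, width)
        finite = depth_map[np.isfinite(depth_map)]
        if finite.size == 0:
            return np.zeros((height, width), dtype=np.float32)
        depth_map[~np.isfinite(depth_map)] = finite.max()
        depth_range = depth_map.max() - depth_map.min()
        if depth_range > 0:
            depth_map = (depth_map - depth_map.min()) / depth_range
        return depth_map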

    def create_colored_pointcloud2(self, points, colors, header):
        """Create a colored PointCloud2 message with proper RGB field alignment"""
        from sensor_msgs_py.point_cloud2 import create_cloud
        from sensor_msgs.msg import PointField
        # Ensure the colors are in the right format
        if colors.dtype != np.uint8:
            colors = colors.astype(np.uint8)
        # Pack RGB values into a single 32-bit integer (recommended for RViz2 compatibility)
        rgb_packed = (colors[:, 0].astype(np.uint32) << 16) | \
                     (colors[:, 1].astype(np.uint32) << 8) | \
                     (colors[:, 2].astype(np.uint32))
        # Create a structured array with XYZ and packed RGB
        dtype = [('x', np.float32), ('y', np.float32), ('z', np.float32), ('rgb', np.uint32)]
        structured_array = np.zeros(len(points), dtype=dtype)
        structured_array['x'] = points[:, 0]
        structured_array['y'] = points[:, 1]
        structured_array['z'] = points[:, 2]
        structured_array['rgb'] = rgb_packed
        # Define the fields with correct offsets
        fields = [
            PointField(name='x', offset=0, datatype=PointField.FLOAT32, count=1),
            PointField(name='y', offset=4, datatype=PointField.FLOAT32, count=1),
            PointField(name='z', offset=8, datatype=PointField.FLOAT32, count=1),
            PointField(name='rgb', offset=12, datatype=PointField.UINT32, count=1),
        ]
        return create_cloud(header, fields, structured_array)
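
    # For completeness, a hypothetical helper (unused by this node) showing the
    # inverse of the packing above, e.g. for a subscriber that wants the
    # channels back:
    @staticmethod
    def unpack_rgb(rgb_packed):
        """Recover (r, g, b) uint8 channels from the packed uint32 'rgb' field"""
        rgb_packed = np.asarray(rgb_packed, dtype=np.uint32)
        r = ((rgb_packed >> 16) & 0xFF).astype(np.uint8)
        g = ((rgb_packed >> 8) & 0xFF).astype(np.uint8)
        b = (rgb_packed & 0xFF).astype(np.uint8)
        return np.stack([r, g, b], axis=-1)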

    def publish_current_point_cloud(self, pred):
        """Publish the current frame's point cloud with colors and downsampling"""
        if not self.get_parameter('publish_current_pointcloud').value:
            return
        try:
            # Extract points from the CUT3R prediction
            if 'pts3d' in pred:
                points = pred['pts3d'].squeeze().cpu().numpy()
            elif 'pts3d_in_other_view' in pred:
                points = pred['pts3d_in_other_view'].squeeze().cpu().numpy()
            else:
                self.get_logger().warn("No points found in current prediction")
                return
            # Ensure a proper (N, 3) shape
            if points.ndim == 1:
                points = points.reshape(-1, 3)
            elif points.ndim > 2:
                points = points.reshape(-1, 3)
            points = points.astype(np.float64)
            # Extract colors from the current image
            colors = self.extract_colors_from_image(points, self.current_image)
            # Rotate points into the ROS coordinate system
            rotated_points = self.rotate_points(points)
            # Downsample the point cloud using the configured voxel size
            # (previously hardcoded to 0.05, the parameter's default)
            voxel_size = self.get_parameter('voxel_size').value
            downsampled_points, downsampled_colors = self.voxel_grid_downsample(
                rotated_points, colors, voxel_size
            )
            # Create the colored point cloud message
            header = std_msgs.msg.Header()
            header.stamp = self.get_clock().now().to_msg()
            header.frame_id = "map"
            pc2_msg = self.create_colored_pointcloud2(downsampled_points, downsampled_colors, header)
            self.current_pointcloud_publisher.publish(pc2_msg)
            self.get_logger().info(f"Published {len(downsampled_points)} colored points (downsampled from {len(points)})")
        except Exception as e:
            self.get_logger().error(f"Failed to publish current point cloud: {e}")

    def publish_accumulated_point_cloud(self, pred):
        """Publish the accumulated dense reconstruction using real CUT3R camera poses for a 360° view"""
        if not self.get_parameter('publish_aggregated_pointcloud').value:
            return
        try:
            # Extract points from the CUT3R prediction
            if 'pts3d' in pred:
                points = pred['pts3d'].squeeze().cpu().numpy()
            elif 'pts3d_in_other_view' in pred:
                points = pred['pts3d_in_other_view'].squeeze().cpu().numpy()
            else:
                return
            if points.ndim == 1:
                points = points.reshape(-1, 3)
            elif points.ndim > 2:
                points = points.reshape(-1, 3)
            points = points.astype(np.float64)
            # Extract colors from the current image
            colors = self.extract_colors_from_image(points, self.current_image)
            # Use only the CUT3R camera pose - no synthetic fallback
            if self.current_cut3r_pose is not None:
                current_pose = self.current_cut3r_pose
                self.get_logger().debug("Using CUT3R estimated camera pose for point cloud transformation")
            else:
                self.get_logger().warn("No CUT3R camera pose available, skipping point cloud accumulation")
                return
            # Transform points from the camera frame to the world frame using the real camera pose
            transformed_points = self.transform_points_to_world_frame(points, current_pose)
            # Rotate points into the ROS coordinate system
            rotated_points = self.rotate_points(transformed_points)
            # Only accumulate if the camera has moved significantly (for 360° reconstruction)
            should_accumulate = self.should_accumulate_frame(current_pose)
            if should_accumulate:
                # Store points, colors, and pose for accumulation
                self.accumulated_points.append(rotated_points)
                self.accumulated_colors.append(colors)
                self.accumulated_poses.append(current_pose)
                self.get_logger().info(f"Accumulated frame {len(self.accumulated_points)} for 360° reconstruction")
            # Clear accumulated data periodically to prevent memory issues
            max_frames = self.get_parameter('max_accumulated_frames').value
            if len(self.accumulated_points) >= max_frames:
                self.get_logger().info(f"Clearing accumulated point cloud after {len(self.accumulated_points)} frames")
                # Keep the last half for a smooth transition
                keep_frames = max_frames // 2
                self.accumulated_points = self.accumulated_points[-keep_frames:]
                self.accumulated_colors = self.accumulated_colors[-keep_frames:]
                self.accumulated_poses = self.accumulated_poses[-keep_frames:]
            # Concatenate all accumulated points and colors
            if len(self.accumulated_points) > 0:
                accumulated_point_cloud = np.concatenate(self.accumulated_points, axis=0)
                accumulated_colors = np.concatenate(self.accumulated_colors, axis=0)
                if accumulated_point_cloud.size == 0:
                    return
                # Downsample the point cloud and colors ('voxel_size' is always declared in __init__)
                voxel_size = self.get_parameter('voxel_size').value
                downsampled_points, downsampled_colors = self.voxel_grid_downsample(
                    accumulated_point_cloud, accumulated_colors, voxel_size
                )
                # Create the colored PointCloud2 message
                header = std_msgs.msg.Header()
                header.stamp = self.get_clock().now().to_msg()
                header.frame_id = "map"
                pc2_msg = self.create_colored_pointcloud2(downsampled_points, downsampled_colors, header)
                self.pointcloud_publisher.publish(pc2_msg)
                self.get_logger().info(
                    f"Published {len(downsampled_points)} colored points (downsampled from {len(accumulated_point_cloud)}) from {len(self.accumulated_points)} frames for 360° reconstruction"
                )
        except Exception as e:
            self.get_logger().error(f"Failed to publish accumulated point cloud: {e}")

    def should_accumulate_frame(self, current_pose):
        """Determine whether the current frame should be accumulated for 360° reconstruction"""
        if len(self.accumulated_poses) == 0:
            # Always accumulate the first frame
            return True
        # Check whether the camera has moved significantly since the last accumulated pose
        last_pose = self.accumulated_poses[-1]
        # Calculate the rotation difference
        rotation_diff = self.calculate_rotation_difference(last_pose, current_pose)
        # Calculate the translation difference
        translation_diff = np.linalg.norm(current_pose[:3, 3] - last_pose[:3, 3])
        # Accumulate on significant movement (for 360° coverage)
        rotation_threshold = 10.0  # degrees
        translation_threshold = 0.2  # meters
        should_accumulate = (rotation_diff > rotation_threshold or
                             translation_diff > translation_threshold)
        if should_accumulate:
            self.get_logger().debug(f"Accumulating frame: rotation_diff={rotation_diff:.2f}°, translation_diff={translation_diff:.3f}m")
        return should_accumulate

    def publish_depth_map(self, pred):
        """Publish the depth map from the current prediction"""
        if not self.get_parameter('publish_depth_map').value:
            return
        try:
            # Extract points for the depth map
            if 'pts3d' in pred:
                points = pred['pts3d'].squeeze().cpu().numpy()
            elif 'pts3d_in_other_view' in pred:
                points = pred['pts3d_in_other_view'].squeeze().cpu().numpy()
            else:
                return
            if points.ndim == 1:
                points = points.reshape(-1, 3)
            elif points.ndim > 2:
                points = points.reshape(-1, 3)
            # Generate the depth map
            depth_map = self.project_points_to_depth_map_current(points, (224, 224))
            # Convert to a ROS image message
            depth_msg = self.cv_bridge.cv2_to_imgmsg(depth_map.astype(np.float32), encoding="32FC1")
            self.depth_map_publisher.publish(depth_msg)
        except Exception as e:
            self.get_logger().error(f"Failed to publish depth map: {e}")

def main(args=None):
    rclpy.init(args=args)
    cut3r_processor = CUT3RProcessor()
    try:
        rclpy.spin(cut3r_processor)
    except KeyboardInterrupt:
        pass
    cut3r_processor.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
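
# Example invocation (hypothetical package/executable names - adjust to your
# workspace; the parameter names are the ones declared in __init__ above):
#   ros2 run cut3r_ros cut3r_package_improved --ros-args \
#       -p cut3r_path:=/home/race10/CUT3R \
#       -p voxel_size:=0.05 -p max_accumulated_frames:=72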