|
| 1 | +import os |
| 2 | +import cv2 |
| 3 | +import rclpy |
| 4 | +import numpy as np |
| 5 | +import ros2_numpy as rnp |
| 6 | +from bosdyn.api import local_grid_pb2 |
| 7 | +from bosdyn.client import create_standard_sdk |
| 8 | +from bosdyn.client.local_grid import LocalGridClient |
| 9 | +from bosdyn.client.math_helpers import SE3Pose |
| 10 | +from geometry_msgs.msg import Pose, Point, Quaternion |
| 11 | +from nav_msgs.msg import OccupancyGrid |
| 12 | +from rclpy.node import Node |
| 13 | +from sensor_msgs.msg import Image, CompressedImage |
| 14 | +from spot_driver.manual_conversions import se3_pose_to_ros_pose |
| 15 | + |
| 16 | +LOCAL_GRID_NAME = 'obstacle_distance' |
| 17 | +VISION_FRAME_NAME = 'vision' |
| 18 | + |
| 19 | + |
| 20 | +class LocalGridPublisher(Node): |
| 21 | + |
| 22 | + def __init__(self): |
| 23 | + super().__init__('local_grid_publisher') |
| 24 | + self.get_logger().info("Initializing LocalGridPublisher Node...") |
| 25 | + |
| 26 | + # Check for robot credentials in environemnt |
| 27 | + self.SPOT_IP = os.environ.get('SPOT_IP') |
| 28 | + self.BOSDYN_CLIENT_USERNAME = os.environ.get('BOSDYN_CLIENT_USERNAME') |
| 29 | + self.BOSDYN_CLIENT_PASSWORD = os.environ.get('BOSDYN_CLIENT_PASSWORD') |
| 30 | + |
| 31 | + if not self.SPOT_IP or not self.BOSDYN_CLIENT_USERNAME or not self.BOSDYN_CLIENT_PASSWORD: |
| 32 | + self.get_logger().error("Robot credentials not found. Ensure that the following environment variables are set: SPOT_IP, BOSDYN_CLIENT_USERNAME, BOSDYN_CLIENT_PASSWORD") |
| 33 | + raise ValueError("Robot credentials not found") |
| 34 | + |
| 35 | + # Verify the credentials are correct |
| 36 | + self.sdk = create_standard_sdk('local_grid_publisher') |
| 37 | + self.robot = self.sdk.create_robot(self.SPOT_IP) |
| 38 | + self.robot.authenticate(self.BOSDYN_CLIENT_USERNAME, self.BOSDYN_CLIENT_PASSWORD) # an exception will be raised if authentication fails |
| 39 | + self.get_logger().info("🔐 Robot authenticated successfully") |
| 40 | + self.get_logger().info("Waiting for time sync...") |
| 41 | + self.robot.time_sync.wait_for_sync() |
| 42 | + self.get_logger().info("🕰 Time sync successful!") |
| 43 | + |
| 44 | + # Create LocalGridClient |
| 45 | + self.get_logger().info("Creating LocalGridClient...") |
| 46 | + |
| 47 | + self.local_grid_client = self.robot.ensure_client(LocalGridClient.default_service_name) |
| 48 | + |
| 49 | + self.get_logger().info("💠 LocalGridClient created successfully!") |
| 50 | + |
| 51 | + # Create ROS2 publisher |
| 52 | + self.get_logger().info("Creating OccupancyGrid publisher...") |
| 53 | + self.occupancy_grid_pub = self.create_publisher(OccupancyGrid, '/autogrammetry/object/local_grid', 10) |
| 54 | + self.get_logger().info("📢 OccupancyGrid publisher created successfully!") |
| 55 | + |
| 56 | + # Indicate successful initialization |
| 57 | + self.get_logger().info('[✓] Spot Local Grid Publisher Node initialized') |
| 58 | + |
| 59 | + self.first_draw_done = False |
| 60 | + self.im = None |
| 61 | + self.fig = None |
| 62 | + self.ax = None |
| 63 | + |
| 64 | + self.fetch_next_grid_data() |
| 65 | + |
| 66 | + |
| 67 | + def fetch_next_grid_data(self): |
| 68 | + future = self.local_grid_client.get_local_grids_async([LOCAL_GRID_NAME]) |
| 69 | + future.add_done_callback(self.publish_grid) |
| 70 | + |
| 71 | + |
| 72 | + def publish_grid(self, future): |
| 73 | + """ |
| 74 | + Converts the local grid protobuf into a ROS occupancy grid message |
| 75 | +
|
| 76 | + Code in this function is adapted from the Boston Dynamics Spot SDK basic_streaming_visualizer example |
| 77 | + """ |
| 78 | + proto = future.result() |
| 79 | + for local_grid_found in proto: |
| 80 | + if local_grid_found.local_grid_type_name == LOCAL_GRID_NAME: |
| 81 | + local_grid_proto = local_grid_found |
| 82 | + cell_size = local_grid_found.local_grid.extent.cell_size |
| 83 | + |
| 84 | + cells_obstacle_dist = self.unpack_grid(local_grid_proto).astype(np.float32) |
| 85 | + cell_count = local_grid_proto.local_grid.extent.num_cells_x * local_grid_proto.local_grid.extent.num_cells_y |
| 86 | + |
| 87 | + # Construct an OccupancyGrid message using the local grid data |
| 88 | + grid = np.zeros([local_grid_proto.local_grid.extent.num_cells_y * local_grid_proto.local_grid.extent.num_cells_x], dtype=np.int8) |
| 89 | + grid[(cells_obstacle_dist <= 0.0)] = 99 |
| 90 | + grid[np.logical_and(0.0 < cells_obstacle_dist, cells_obstacle_dist < 0.33)] = -1 |
| 91 | + grid = grid.reshape(local_grid_proto.local_grid.extent.num_cells_y, local_grid_proto.local_grid.extent.num_cells_x) |
| 92 | + |
| 93 | + grid_msg = rnp.msgify(OccupancyGrid, grid) # Grid data converted using ros2_numpy |
| 94 | + grid_msg.header.frame_id = VISION_FRAME_NAME |
| 95 | + |
| 96 | + # Timestamp data from protobuf |
| 97 | + grid_msg.header.stamp.sec = local_grid_proto.local_grid.acquisition_time.seconds |
| 98 | + grid_msg.header.stamp.nanosec = local_grid_proto.local_grid.acquisition_time.nanos |
| 99 | + grid_msg.info.map_load_time.sec = local_grid_proto.local_grid.acquisition_time.seconds |
| 100 | + grid_msg.info.map_load_time.nanosec = local_grid_proto.local_grid.acquisition_time.nanos |
| 101 | + |
| 102 | + # Spatial information |
| 103 | + grid_msg.info.resolution = local_grid_proto.local_grid.extent.cell_size |
| 104 | + transform = self.get_a_tform_b(local_grid_proto.local_grid.transforms_snapshot, VISION_FRAME_NAME, |
| 105 | + local_grid_proto.local_grid.frame_name_local_grid_data) |
| 106 | + |
| 107 | + grid_msg.info.origin = se3_pose_to_ros_pose(transform) |
| 108 | + |
| 109 | + |
| 110 | + # Publish and begin the next fetch |
| 111 | + self.occupancy_grid_pub.publish(grid_msg) |
| 112 | + self.fetch_next_grid_data() |
| 113 | + |
| 114 | + |
| 115 | + # Helper functions for local grid processing - functions taken from Bosdyn Dynamics Spot SDK visualizer example |
| 116 | + def unpack_grid(self, local_grid_proto): |
| 117 | + """Unpack the local grid proto.""" |
| 118 | + # Determine the data type for the bytes data. |
| 119 | + data_type = self.get_numpy_data_type(local_grid_proto.local_grid) |
| 120 | + if data_type is None: |
| 121 | + print('Cannot determine the dataformat for the local grid.') |
| 122 | + return None |
| 123 | + # Decode the local grid. |
| 124 | + if local_grid_proto.local_grid.encoding == local_grid_pb2.LocalGrid.ENCODING_RAW: |
| 125 | + full_grid = np.frombuffer(local_grid_proto.local_grid.data, dtype=data_type) |
| 126 | + elif local_grid_proto.local_grid.encoding == local_grid_pb2.LocalGrid.ENCODING_RLE: |
| 127 | + full_grid = self.expand_data_by_rle_count(local_grid_proto, data_type=data_type) |
| 128 | + else: |
| 129 | + # Return nothing if there is no encoding type set. |
| 130 | + return None |
| 131 | + # Apply the offset and scaling to the local grid. |
| 132 | + if local_grid_proto.local_grid.cell_value_scale == 0: |
| 133 | + return full_grid |
| 134 | + full_grid_float = full_grid.astype(np.float64) |
| 135 | + full_grid_float *= local_grid_proto.local_grid.cell_value_scale |
| 136 | + full_grid_float += local_grid_proto.local_grid.cell_value_offset |
| 137 | + return full_grid_float |
| 138 | + |
| 139 | + |
| 140 | + def get_numpy_data_type(self, local_grid_proto): |
| 141 | + """Convert the cell format of the local grid proto to a numpy data type.""" |
| 142 | + if local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_UINT16: |
| 143 | + return np.uint16 |
| 144 | + elif local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_INT16: |
| 145 | + return np.int16 |
| 146 | + elif local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_UINT8: |
| 147 | + return np.uint8 |
| 148 | + elif local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_INT8: |
| 149 | + return np.int8 |
| 150 | + elif local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_FLOAT64: |
| 151 | + return np.float64 |
| 152 | + elif local_grid_proto.cell_format == local_grid_pb2.LocalGrid.CELL_FORMAT_FLOAT32: |
| 153 | + return np.float32 |
| 154 | + else: |
| 155 | + return None |
| 156 | + |
| 157 | + |
| 158 | + def expand_data_by_rle_count(self, local_grid_proto, data_type=np.int16): |
| 159 | + """Expand local grid data to full bytes data using the RLE count.""" |
| 160 | + cells_pz = np.frombuffer(local_grid_proto.local_grid.data, dtype=data_type) |
| 161 | + cells_pz_full = [] |
| 162 | + # For each value of rle_counts, we expand the cell data at the matching index |
| 163 | + # to have that many repeated, consecutive values. |
| 164 | + for i in range(0, len(local_grid_proto.local_grid.rle_counts)): |
| 165 | + for j in range(0, local_grid_proto.local_grid.rle_counts[i]): |
| 166 | + cells_pz_full.append(cells_pz[i]) |
| 167 | + return np.array(cells_pz_full) |
| 168 | + |
| 169 | + |
| 170 | + def get_a_tform_b(self, frame_tree_snapshot, frame_a, frame_b): |
| 171 | + """Get the SE(3) pose representing the transform between frame_a and frame_b. |
| 172 | +
|
| 173 | + Using frame_tree_snapshot, find the math_helpers.SE3Pose to transform geometry from |
| 174 | + frame_a's representation to frame_b's. |
| 175 | +
|
| 176 | + Args: |
| 177 | + frame_tree_snapshot (dict) dictionary representing the child_to_parent_edge_map |
| 178 | + frame_a (string) |
| 179 | + frame_b (string) |
| 180 | + validate (bool) if the FrameTreeSnapshot should be checked for a valid tree structure |
| 181 | +
|
| 182 | + Returns: |
| 183 | + math_helpers.SE3Pose between frame_a and frame_b if they exist in the tree. None otherwise. |
| 184 | + """ |
| 185 | + |
| 186 | + if frame_a not in frame_tree_snapshot.child_to_parent_edge_map: |
| 187 | + return None |
| 188 | + if frame_b not in frame_tree_snapshot.child_to_parent_edge_map: |
| 189 | + return None |
| 190 | + |
| 191 | + def _list_parent_edges(leaf_frame): |
| 192 | + parent_edges = [] |
| 193 | + cur_frame = leaf_frame |
| 194 | + while True: |
| 195 | + parent_edge = frame_tree_snapshot.child_to_parent_edge_map.get(cur_frame) |
| 196 | + if not parent_edge.parent_frame_name: |
| 197 | + break |
| 198 | + parent_edges.append(parent_edge) |
| 199 | + cur_frame = parent_edge.parent_frame_name |
| 200 | + return parent_edges |
| 201 | + |
| 202 | + inverse_edges = _list_parent_edges(frame_a) |
| 203 | + forward_edges = _list_parent_edges(frame_b) |
| 204 | + |
| 205 | + # Possible optimization: Nearest common ancestor pruning. |
| 206 | + |
| 207 | + def _accumulate_transforms(parent_edges): |
| 208 | + ret = SE3Pose.from_identity() |
| 209 | + for parent_edge in parent_edges: |
| 210 | + ret = SE3Pose.from_proto(parent_edge.parent_tform_child) * ret |
| 211 | + return ret |
| 212 | + |
| 213 | + frame_a_tform_root_frame = _accumulate_transforms(inverse_edges).inverse() |
| 214 | + root_frame_tform_frame_b = _accumulate_transforms(forward_edges) |
| 215 | + return frame_a_tform_root_frame * root_frame_tform_frame_b |
| 216 | + |
| 217 | + |
| 218 | +def main(args=None): |
| 219 | + rclpy.init(args=args) |
| 220 | + try: |
| 221 | + node = LocalGridPublisher() |
| 222 | + except Exception as e: |
| 223 | + rclpy.shutdown() |
| 224 | + return |
| 225 | + |
| 226 | + rclpy.spin(node) |
| 227 | + node.destroy_node() |
| 228 | + rclpy.shutdown() |
| 229 | + |
| 230 | + |
| 231 | +if __name__ == '__main__': |
| 232 | + main() |
0 commit comments